mirror of https://github.com/langgenius/dify.git
feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
parent 7dadb33003
commit 85cda47c70
@@ -8,6 +8,8 @@ on:
       - "deploy/enterprise"
       - "build/**"
       - "release/e-*"
+      - "deploy/rag-dev"
+      - "feat/rag-2"
     tags:
       - "*"
@@ -4,7 +4,7 @@ on:
   workflow_run:
     workflows: ["Build and Push API & Web"]
     branches:
-      - "deploy/dev"
+      - "deploy/rag-dev"
     types:
      - completed

@@ -12,12 +12,13 @@ jobs:
   deploy:
     runs-on: ubuntu-latest
     if: |
-      github.event.workflow_run.conclusion == 'success'
+      github.event.workflow_run.conclusion == 'success' &&
+      github.event.workflow_run.head_branch == 'deploy/rag-dev'
     steps:
       - name: Deploy to server
         uses: appleboy/ssh-action@v0.1.8
         with:
-          host: ${{ secrets.SSH_HOST }}
+          host: ${{ secrets.RAG_SSH_HOST }}
           username: ${{ secrets.SSH_USER }}
           key: ${{ secrets.SSH_PRIVATE_KEY }}
           script: |
@@ -12,7 +12,6 @@ permissions:
   statuses: write
   contents: read
-

 jobs:
   python-style:
     name: Python Style

@@ -44,6 +43,10 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv sync --project api --dev

+      - name: Run Import Linter
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: uv run --directory api --dev lint-imports
+
       - name: Run Basedpyright Checks
         if: steps.changed-files.outputs.any_changed == 'true'
         run: dev/basedpyright-check
@@ -461,6 +461,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

+# GraphEngine Worker Pool Configuration
+# Minimum number of workers per GraphEngine instance (default: 1)
+GRAPH_ENGINE_MIN_WORKERS=1
+# Maximum number of workers per GraphEngine instance (default: 10)
+GRAPH_ENGINE_MAX_WORKERS=10
+# Queue depth threshold that triggers worker scale up (default: 3)
+GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
+# Seconds of idle time before scaling down workers (default: 5.0)
+GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
+
 # Workflow storage configuration
 # Options: rdbms, hybrid
 # rdbms: Use only the relational database (default)
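The four GRAPH_ENGINE_* values above describe an auto-scaling worker pool: a GraphEngine instance adds workers while its queue backs up past the threshold and releases them again after a stretch of idleness, staying within the min/max bounds. A minimal sketch of that decision rule, written only against the documented semantics of these variables (the real GraphEngine worker-management code is not part of this hunk):

def desired_worker_delta(
    queue_depth: int,
    current_workers: int,
    idle_seconds: float,
    min_workers: int = 1,
    max_workers: int = 10,
    scale_up_threshold: int = 3,
    scale_down_idle_time: float = 5.0,
) -> int:
    """Return +1 to add a worker, -1 to retire one, 0 to leave the pool alone.

    Illustrative only: the defaults mirror the env vars above, not actual engine code.
    """
    if queue_depth > scale_up_threshold and current_workers < max_workers:
        return 1
    if idle_seconds >= scale_down_idle_time and current_workers > min_workers:
        return -1
    return 0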
@@ -0,0 +1,105 @@
+[importlinter]
+root_packages =
+    core
+    configs
+    controllers
+    models
+    tasks
+    services
+
+[importlinter:contract:workflow]
+name = Workflow
+type=layers
+layers =
+    graph_engine
+    graph_events
+    graph
+    nodes
+    node_events
+    entities
+containers =
+    core.workflow
+ignore_imports =
+    core.workflow.nodes.base.node -> core.workflow.graph_events
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
+
+    core.workflow.nodes.node_factory -> core.workflow.graph
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
+
+[importlinter:contract:rsc]
+name = RSC
+type = layers
+layers =
+    graph_engine
+    response_coordinator
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:worker]
+name = Worker
+type = layers
+layers =
+    graph_engine
+    worker
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:graph-engine-architecture]
+name = Graph Engine Architecture
+type = layers
+layers =
+    graph_engine
+    orchestration
+    command_processing
+    event_management
+    error_handler
+    graph_traversal
+    graph_state_manager
+    worker_management
+    domain
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:domain-isolation]
+name = Domain Model Isolation
+type = forbidden
+source_modules =
+    core.workflow.graph_engine.domain
+forbidden_modules =
+    core.workflow.graph_engine.worker_management
+    core.workflow.graph_engine.command_channels
+    core.workflow.graph_engine.layers
+    core.workflow.graph_engine.protocols
+
+[importlinter:contract:worker-management]
+name = Worker Management
+type = forbidden
+source_modules =
+    core.workflow.graph_engine.worker_management
+forbidden_modules =
+    core.workflow.graph_engine.orchestration
+    core.workflow.graph_engine.command_processing
+    core.workflow.graph_engine.event_management
+
+
+[importlinter:contract:graph-traversal-components]
+name = Graph Traversal Components
+type = layers
+layers =
+    edge_processor
+    skip_propagator
+containers =
+    core.workflow.graph_engine.graph_traversal
+
+[importlinter:contract:command-channels]
+name = Command Channels Independence
+type = independence
+modules =
+    core.workflow.graph_engine.command_channels.in_memory_channel
+    core.workflow.graph_engine.command_channels.redis_channel
api/app.py (27 changed lines)
@@ -1,4 +1,3 @@
-import os
 import sys


@@ -17,20 +16,20 @@ else:
     # It seems that JetBrains Python debugger does not work well with gevent,
     # so we need to disable gevent in debug mode.
     # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
-    if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-        from gevent import monkey
-
-        # gevent
-        monkey.patch_all()
-
-        from grpc.experimental import gevent as grpc_gevent  # type: ignore
-
-        # grpc gevent
-        grpc_gevent.init_gevent()
-
-        import psycogreen.gevent  # type: ignore
-
-        psycogreen.gevent.patch_psycopg()
+    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
+    #     from gevent import monkey
+    #
+    #     # gevent
+    #     monkey.patch_all()
+    #
+    #     from grpc.experimental import gevent as grpc_gevent  # type: ignore
+    #
+    #     # grpc gevent
+    #     grpc_gevent.init_gevent()
+    #
+    #     import psycogreen.gevent  # type: ignore
+    #
+    #     psycogreen.gevent.patch_psycopg()

     from app_factory import create_app
@@ -0,0 +1,22 @@
+import logging
+
+import psycogreen.gevent as pscycogreen_gevent  # type: ignore
+from grpc.experimental import gevent as grpc_gevent  # type: ignore
+
+_logger = logging.getLogger(__name__)
+
+
+def _log(message: str):
+    print(message, flush=True)
+
+
+# grpc gevent
+grpc_gevent.init_gevent()
+_log("gRPC patched with gevent.")
+pscycogreen_gevent.patch_psycopg()
+_log("psycopg2 patched with gevent.")
+
+
+from app import app, celery
+
+__all__ = ["app", "celery"]
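The new module above (its file name is not shown in this hunk) applies the gRPC and psycopg gevent patches at import time and only then re-exports app and celery, which is what lets the equivalent block in api/app.py be commented out. A small usage sketch, with the module name assumed purely for illustration:

# "gevent_entrypoint" is a placeholder; the diff does not show the real file name.
# The point is ordering: importing the module runs grpc_gevent.init_gevent() and
# psycogreen's patch_psycopg() before the Flask app and Celery app are imported.
from gevent_entrypoint import app, celery  # hypothetical module name

assert app is not None and celery is not None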
api/commands.py (268 changed lines)
@@ -1,7 +1,6 @@
 import base64
 import json
 import logging
-import operator
 import secrets
 from typing import Any

@@ -14,11 +13,13 @@ from sqlalchemy.exc import SQLAlchemyError
 from configs import dify_config
 from constants.languages import languages
-from core.plugin.entities.plugin import ToolProviderID
+from core.helper import encrypter
+from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.index_processor.constant.built_in_field import BuiltInField
 from core.rag.models.document import Document
+from core.tools.entities.tool_entities import CredentialType
 from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
 from events.app_event import app_was_created
 from extensions.ext_database import db

@@ -31,12 +32,16 @@ from models import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
+from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from models.provider import Provider, ProviderModel
+from models.provider_ids import DatasourceProviderID, ToolProviderID
+from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
+from services.plugin.plugin_service import PluginService
 from tasks.remove_app_and_related_data_task import delete_draft_variables_batch

 logger = logging.getLogger(__name__)
@@ -1246,15 +1251,17 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:

 def _count_orphaned_draft_variables() -> dict[str, Any]:
     """
-    Count orphaned draft variables by app.
+    Count orphaned draft variables by app, including associated file counts.

     Returns:
-        Dictionary with statistics about orphaned variables
+        Dictionary with statistics about orphaned variables and files
     """
-    query = """
+    # Count orphaned variables by app
+    variables_query = """
         SELECT
             wdv.app_id,
-            COUNT(*) as variable_count
+            COUNT(*) as variable_count,
+            COUNT(wdv.file_id) as file_count
         FROM workflow_draft_variables AS wdv
         WHERE NOT EXISTS(
             SELECT 1 FROM apps WHERE apps.id = wdv.app_id

@@ -1264,14 +1271,21 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
     """

     with db.engine.connect() as conn:
-        result = conn.execute(sa.text(query))
-        orphaned_by_app = {row[0]: row[1] for row in result}
-
-    total_orphaned = sum(orphaned_by_app.values())
+        result = conn.execute(sa.text(variables_query))
+        orphaned_by_app = {}
+        total_files = 0
+
+        for row in result:
+            app_id, variable_count, file_count = row
+            orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
+            total_files += file_count
+
+    total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
     app_count = len(orphaned_by_app)

     return {
         "total_orphaned_variables": total_orphaned,
+        "total_orphaned_files": total_files,
         "orphaned_app_count": app_count,
         "orphaned_by_app": orphaned_by_app,
     }
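With the file counting added above, the stats dictionary consumed by cleanup_orphaned_draft_variables carries one entry per missing app. An illustrative value (all numbers and app IDs invented; the keys come from the code above):

stats = {
    "total_orphaned_variables": 42,
    "total_orphaned_files": 7,
    "orphaned_app_count": 2,
    "orphaned_by_app": {
        "app-id-1": {"variables": 30, "files": 5},  # hypothetical app IDs
        "app-id-2": {"variables": 12, "files": 2},
    },
}

# The dry-run report sorts apps by variable count, matching the updated key function:
top_apps = sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[:10]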
@@ -1300,6 +1314,7 @@ def cleanup_orphaned_draft_variables(
     stats = _count_orphaned_draft_variables()

     logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
+    logger.info("Found %s associated offload files", stats["total_orphaned_files"])
     logger.info("Across %s non-existent apps", stats["orphaned_app_count"])

     if stats["total_orphaned_variables"] == 0:

@@ -1308,10 +1323,10 @@

     if dry_run:
         logger.info("DRY RUN: Would delete the following:")
-        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
+        for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
            :10
        ]:  # Show top 10
-            logger.info(" App %s: %s variables", app_id, count)
+            logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
        if len(stats["orphaned_by_app"]) > 10:
            logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
        return

@@ -1320,7 +1335,8 @@
     if not force:
         click.confirm(
             f"Are you sure you want to delete {stats['total_orphaned_variables']} "
-            f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
+            f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
+            f"from {stats['orphaned_app_count']} apps?",
             abort=True,
         )
@@ -1353,3 +1369,231 @@ def cleanup_orphaned_draft_variables(
             continue

     logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
+
+
+@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
+@click.option("--provider", prompt=True, help="Provider name")
+@click.option("--client-params", prompt=True, help="Client Params")
+def setup_datasource_oauth_client(provider, client_params):
+    """
+    Setup datasource oauth client
+    """
+    provider_id = DatasourceProviderID(provider)
+    provider_name = provider_id.provider_name
+    plugin_id = provider_id.plugin_id
+
+    try:
+        # json validate
+        click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
+        client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
+        click.echo(click.style("Client params validated successfully.", fg="green"))
+    except Exception as e:
+        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
+        return
+
+    click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
+    deleted_count = (
+        db.session.query(DatasourceOauthParamConfig)
+        .filter_by(
+            provider=provider_name,
+            plugin_id=plugin_id,
+        )
+        .delete()
+    )
+    if deleted_count > 0:
+        click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
+
+    click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
+    oauth_client = DatasourceOauthParamConfig(
+        provider=provider_name,
+        plugin_id=plugin_id,
+        system_credentials=client_params_dict,
+    )
+    db.session.add(oauth_client)
+    db.session.commit()
+    click.echo(click.style(f"provider: {provider_name}", fg="green"))
+    click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
+    click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
+    click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
+
+
+@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
+def transform_datasource_credentials():
+    """
+    Transform datasource credentials
+    """
+    try:
+        installer_manager = PluginInstaller()
+        plugin_migration = PluginMigration()
+
+        notion_plugin_id = "langgenius/notion_datasource"
+        firecrawl_plugin_id = "langgenius/firecrawl_datasource"
+        jina_plugin_id = "langgenius/jina_datasource"
+        notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)  # pyright: ignore[reportPrivateUsage]
+        firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)  # pyright: ignore[reportPrivateUsage]
+        jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)  # pyright: ignore[reportPrivateUsage]
+        oauth_credential_type = CredentialType.OAUTH2
+        api_key_credential_type = CredentialType.API_KEY
+
+        # deal notion credentials
+        deal_notion_count = 0
+        notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
+        if notion_credentials:
+            notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
+            for notion_credential in notion_credentials:
+                tenant_id = notion_credential.tenant_id
+                if tenant_id not in notion_credentials_tenant_mapping:
+                    notion_credentials_tenant_mapping[tenant_id] = []
+                notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
+            for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
+                # check notion plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if notion_plugin_id not in installed_plugins_ids:
+                    if notion_plugin_unique_identifier:
+                        # install notion plugin
+                        PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
+                auth_count = 0
+                for notion_tenant_credential in notion_tenant_credentials:
+                    auth_count += 1
+                    # get credential oauth params
+                    access_token = notion_tenant_credential.access_token
+                    # notion info
+                    notion_info = notion_tenant_credential.source_info
+                    workspace_id = notion_info.get("workspace_id")
+                    workspace_name = notion_info.get("workspace_name")
+                    workspace_icon = notion_info.get("workspace_icon")
+                    new_credentials = {
+                        "integration_secret": encrypter.encrypt_token(tenant_id, access_token),
+                        "workspace_id": workspace_id,
+                        "workspace_name": workspace_name,
+                        "workspace_icon": workspace_icon,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="notion_datasource",
+                        tenant_id=tenant_id,
+                        plugin_id=notion_plugin_id,
+                        auth_type=oauth_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url=workspace_icon or "default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_notion_count += 1
+        db.session.commit()
+        # deal firecrawl credentials
+        deal_firecrawl_count = 0
+        firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
+        if firecrawl_credentials:
+            firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
+            for firecrawl_credential in firecrawl_credentials:
+                tenant_id = firecrawl_credential.tenant_id
+                if tenant_id not in firecrawl_credentials_tenant_mapping:
+                    firecrawl_credentials_tenant_mapping[tenant_id] = []
+                firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
+            for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
+                # check firecrawl plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if firecrawl_plugin_id not in installed_plugins_ids:
+                    if firecrawl_plugin_unique_identifier:
+                        # install firecrawl plugin
+                        PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
+
+                auth_count = 0
+                for firecrawl_tenant_credential in firecrawl_tenant_credentials:
+                    auth_count += 1
+                    # get credential api key
+                    credentials_json = json.loads(firecrawl_tenant_credential.credentials)
+                    api_key = credentials_json.get("config", {}).get("api_key")
+                    base_url = credentials_json.get("config", {}).get("base_url")
+                    new_credentials = {
+                        "firecrawl_api_key": api_key,
+                        "base_url": base_url,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="firecrawl",
+                        tenant_id=tenant_id,
+                        plugin_id=firecrawl_plugin_id,
+                        auth_type=api_key_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url="default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_firecrawl_count += 1
+        db.session.commit()
+        # deal jina credentials
+        deal_jina_count = 0
+        jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
+        if jina_credentials:
+            jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
+            for jina_credential in jina_credentials:
+                tenant_id = jina_credential.tenant_id
+                if tenant_id not in jina_credentials_tenant_mapping:
+                    jina_credentials_tenant_mapping[tenant_id] = []
+                jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
+            for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
+                # check jina plugin is installed
+                installed_plugins = installer_manager.list_plugins(tenant_id)
+                installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
+                if jina_plugin_id not in installed_plugins_ids:
+                    if jina_plugin_unique_identifier:
+                        # install jina plugin
+                        print(jina_plugin_unique_identifier)
+                        PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
+
+                auth_count = 0
+                for jina_tenant_credential in jina_tenant_credentials:
+                    auth_count += 1
+                    # get credential api key
+                    credentials_json = json.loads(jina_tenant_credential.credentials)
+                    api_key = credentials_json.get("config", {}).get("api_key")
+                    new_credentials = {
+                        "integration_secret": api_key,
+                    }
+                    datasource_provider = DatasourceProvider(
+                        provider="jina",
+                        tenant_id=tenant_id,
+                        plugin_id=jina_plugin_id,
+                        auth_type=api_key_credential_type.value,
+                        encrypted_credentials=new_credentials,
+                        name=f"Auth {auth_count}",
+                        avatar_url="default",
+                        is_default=False,
+                    )
+                    db.session.add(datasource_provider)
+                    deal_jina_count += 1
+        db.session.commit()
+    except Exception as e:
+        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
+        return
+    click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
+    click.echo(
+        click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
+    )
+    click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
+
+
+@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
+@click.option(
+    "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
+)
+@click.option(
+    "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
+)
+@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
+def install_rag_pipeline_plugins(input_file, output_file, workers):
+    """
+    Install rag pipeline plugins
+    """
+    click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
+    plugin_migration = PluginMigration()
+    plugin_migration.install_rag_pipeline_plugins(
+        input_file,
+        output_file,
+        workers,
+    )
+    click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
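The setup-datasource-oauth-client command above accepts --client-params as a JSON object and validates it with TypeAdapter(dict[str, Any]).validate_json, so any well-formed object passes; the concrete keys depend on the datasource plugin. A hedged example of building such a value (the key names are assumptions, not part of this diff):

import json

client_params = json.dumps(
    {
        "client_id": "your-oauth-client-id",          # hypothetical key
        "client_secret": "your-oauth-client-secret",  # hypothetical key
    }
)
# This is the string the command parses and stores as system_credentials
# on the DatasourceOauthParamConfig row it creates.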
@@ -1,3 +1,3 @@
 from .app_config import DifyConfig

-dify_config = DifyConfig()
+dify_config = DifyConfig()  # type: ignore
@@ -505,6 +505,22 @@ class UpdateConfig(BaseSettings):
     )


+class WorkflowVariableTruncationConfig(BaseSettings):
+    WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
+        # 100KB
+        1024_000,
+        description="Maximum size for variable to trigger final truncation.",
+    )
+    WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
+        100000,
+        description="maximum length for string to trigger tuncation, measure in number of characters",
+    )
+    WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
+        1000,
+        description="maximum length for array to trigger truncation.",
+    )
+
+
 class WorkflowConfig(BaseSettings):
     """
     Configuration for workflow execution

@@ -535,6 +551,28 @@ class WorkflowConfig(BaseSettings):
         default=200 * 1024,
     )

+    # GraphEngine Worker Pool Configuration
+    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
+        description="Minimum number of workers per GraphEngine instance",
+        default=1,
+    )
+
+    GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
+        description="Maximum number of workers per GraphEngine instance",
+        default=10,
+    )
+
+    GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
+        description="Queue depth threshold that triggers worker scale up",
+        default=3,
+    )
+
+    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
+        description="Seconds of idle time before scaling down workers",
+        default=5.0,
+        ge=0.1,
+    )
+

 class WorkflowNodeExecutionConfig(BaseSettings):
     """

@@ -1041,5 +1079,6 @@ class FeatureConfig(
     CeleryBeatConfig,
     CeleryScheduleTasksConfig,
     WorkflowLogConfig,
+    WorkflowVariableTruncationConfig,
 ):
     pass
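Because WorkflowConfig is part of the composed DifyConfig, the new worker pool fields are read like any other setting. A minimal sketch, assuming only the `from configs import dify_config` import that already appears elsewhere in this diff:

from configs import dify_config

print(dify_config.GRAPH_ENGINE_MIN_WORKERS)           # 1 unless overridden via env
print(dify_config.GRAPH_ENGINE_MAX_WORKERS)           # 10
print(dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD)    # 3
print(dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME)  # 5.0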
@@ -220,11 +220,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     )


+class HostedFetchPipelineTemplateConfig(BaseSettings):
+    """
+    Configuration for fetching pipeline templates
+    """
+
+    HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
+        description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
+        default="remote",
+    )
+
+    HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
+        description="Domain for fetching remote pipeline templates",
+        default="https://tmpl.dify.ai",
+    )
+
+
 class HostedServiceConfig(
     # place the configs in alphabet order
     HostedAnthropicConfig,
     HostedAzureOpenAiConfig,
     HostedFetchAppTemplateConfig,
+    HostedFetchPipelineTemplateConfig,
     HostedMinmaxConfig,
     HostedOpenAiConfig,
     HostedSparkConfig,
@@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:


 # Returns whether the obtained value is obtained, and None if it does not
-def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
+def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
     if namespace_cache:
         kv_data = namespace_cache.get(CONFIGURATIONS)
         if kv_data is None:
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from contexts.wrapper import RecyclableContextVar

 if TYPE_CHECKING:
+    from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
     from core.model_runtime.entities.model_entities import AIModelEntity
     from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
     from core.tools.plugin_tool.provider import PluginToolProviderController

@@ -32,3 +33,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
 plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
     ContextVar("plugin_model_schemas")
 )
+
+datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
+    RecyclableContextVar(ContextVar("datasource_plugin_providers"))
+)
+
+datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
+    ContextVar("datasource_plugin_providers_lock")
+)
@@ -61,6 +61,7 @@ from . import (
     init_validate,
     ping,
     setup,
+    spec,
     version,
 )

@@ -114,6 +115,15 @@ from .datasets import (
     metadata,
     website,
 )
+from .datasets.rag_pipeline import (
+    datasource_auth,
+    datasource_content_preview,
+    rag_pipeline,
+    rag_pipeline_datasets,
+    rag_pipeline_draft_variable,
+    rag_pipeline_import,
+    rag_pipeline_workflow,
+)

 # Import explore controllers
 from .explore import (

@@ -238,6 +248,8 @@ __all__ = [
     "datasets",
     "datasets_document",
     "datasets_segments",
+    "datasource_auth",
+    "datasource_content_preview",
     "email_register",
     "endpoint",
     "extension",

@@ -263,10 +275,16 @@ __all__ = [
     "parameter",
     "ping",
     "plugin",
+    "rag_pipeline",
+    "rag_pipeline_datasets",
+    "rag_pipeline_draft_variable",
+    "rag_pipeline_import",
+    "rag_pipeline_workflow",
     "recommended_app",
     "saved_message",
     "setup",
     "site",
+    "spec",
     "statistic",
     "tags",
     "tool_providers",
@@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
 from libs.login import login_required
+from models import App
+from services.workflow_service import WorkflowService


 @console_ns.route("/rule-generate")

@@ -205,9 +208,6 @@ class InstructionGenerateApi(Resource):
         try:
             # Generate from nothing for a workflow node
             if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
-                from models import App, db
-                from services.workflow_service import WorkflowService
-
                 app = db.session.query(App).where(App.id == args["flow_id"]).first()
                 if not app:
                     return {"error": f"app {args['flow_id']} not found"}, 400

@@ -261,6 +261,7 @@ class InstructionGenerateApi(Resource):
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 ideal_output=args["ideal_output"],
+                workflow_service=WorkflowService(),
             )
         return {"error": "incompatible parameters"}, 400
     except ProviderTokenNotInitError as ex:
@@ -20,6 +20,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file.models import File
 from core.helper.trace_id_helper import get_external_trace_id
+from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from factories import file_factory, variable_factory
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields

@@ -536,7 +537,12 @@ class WorkflowTaskStopApi(Resource):
         if not current_user.has_edit_permission:
             raise Forbidden()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
 from controllers.console import api, console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_database import db
 from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs.login import login_required
@@ -13,14 +13,16 @@ from controllers.console.app.error import (
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
+from core.file import helpers as file_helpers
 from core.variables.segment_group import SegmentGroup
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
 from libs.login import current_user, login_required
-from models import App, AppMode, db
+from models import App, AppMode
 from models.account import Account
 from models.workflow import WorkflowDraftVariable
 from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

@@ -74,6 +76,22 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
     return value_type.exposed_type().value


+def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
+    """Serialize full_content information for large variables."""
+    if not variable.is_truncated():
+        return None
+
+    variable_file = variable.variable_file
+    assert variable_file is not None
+
+    return {
+        "size_bytes": variable_file.size,
+        "value_type": variable_file.value_type.exposed_type().value,
+        "length": variable_file.length,
+        "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
+    }
+
+
 _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
     "id": fields.String,
     "type": fields.String(attribute=lambda model: model.get_variable_type()),

@@ -83,11 +101,13 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
     "value_type": fields.String(attribute=_serialize_variable_type),
     "edited": fields.Boolean(attribute=lambda model: model.edited),
     "visible": fields.Boolean,
+    "is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
 }

 _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
     _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
     value=fields.Raw(attribute=_serialize_var_value),
+    full_content=fields.Raw(attribute=_serialize_full_content),
 )

 _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
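With is_truncated and full_content wired into the marshalling fields above, a large draft variable serializes with a preview value plus a pointer to the offloaded file. An illustrative payload (every value below is invented; only the keys come from the fields and _serialize_full_content above):

example_variable = {
    "id": "variable-id",
    "value_type": "string",
    "edited": False,
    "visible": True,
    "is_truncated": True,
    "value": "first part of a very large value...",  # truncated preview
    "full_content": {
        "size_bytes": 1048576,
        "value_type": "string",
        "length": 100001,
        "download_url": "https://example.com/signed-url",  # produced by get_signed_file_url in the real code
    },
}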
@@ -1,4 +1,6 @@
 import json
+from collections.abc import Generator
+from typing import cast

 from flask import request
 from flask_login import current_user

@@ -9,6 +11,8 @@ from werkzeug.exceptions import NotFound

 from controllers.console import api
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting

@@ -19,6 +23,7 @@ from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService
+from services.datasource_provider_service import DatasourceProviderService
 from tasks.document_indexing_sync_task import document_indexing_sync_task


@@ -111,6 +116,18 @@ class DataSourceNotionListApi(Resource):
     @marshal_with(integrate_notion_info_list_fields)
     def get(self):
         dataset_id = request.args.get("dataset_id", default=None, type=str)
+        credential_id = request.args.get("credential_id", default=None, type=str)
+        if not credential_id:
+            raise ValueError("Credential id is required.")
+        datasource_provider_service = DatasourceProviderService()
+        credential = datasource_provider_service.get_datasource_credentials(
+            tenant_id=current_user.current_tenant_id,
+            credential_id=credential_id,
+            provider="notion_datasource",
+            plugin_id="langgenius/notion_datasource",
+        )
+        if not credential:
+            raise NotFound("Credential not found.")
         exist_page_ids = []
         with Session(db.engine) as session:
             # import notion in the exist dataset

@@ -134,31 +151,49 @@ class DataSourceNotionListApi(Resource):
                 data_source_info = json.loads(document.data_source_info)
                 exist_page_ids.append(data_source_info["notion_page_id"])
             # get all authorized pages
-            data_source_bindings = session.scalars(
-                select(DataSourceOauthBinding).filter_by(
-                    tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-                )
-            ).all()
-            if not data_source_bindings:
-                return {"notion_info": []}, 200
-            pre_import_info_list = []
-            for data_source_binding in data_source_bindings:
-                source_info = data_source_binding.source_info
-                pages = source_info["pages"]
-                # Filter out already bound pages
-                for page in pages:
-                    if page["page_id"] in exist_page_ids:
-                        page["is_bound"] = True
-                    else:
-                        page["is_bound"] = False
-                pre_import_info = {
-                    "workspace_name": source_info["workspace_name"],
-                    "workspace_icon": source_info["workspace_icon"],
-                    "workspace_id": source_info["workspace_id"],
-                    "pages": pages,
-                }
-                pre_import_info_list.append(pre_import_info)
-            return {"notion_info": pre_import_info_list}, 200
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id="langgenius/notion_datasource/notion_datasource",
+                datasource_name="notion_datasource",
+                tenant_id=current_user.current_tenant_id,
+                datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
+            )
+            datasource_provider_service = DatasourceProviderService()
+            if credential:
+                datasource_runtime.runtime.credentials = credential
+            datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+            online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+                datasource_runtime.get_online_document_pages(
+                    user_id=current_user.id,
+                    datasource_parameters={},
+                    provider_type=datasource_runtime.datasource_provider_type(),
+                )
+            )
+            try:
+                pages = []
+                workspace_info = {}
+                for message in online_document_result:
+                    result = message.result
+                    for info in result:
+                        workspace_info = {
+                            "workspace_id": info.workspace_id,
+                            "workspace_name": info.workspace_name,
+                            "workspace_icon": info.workspace_icon,
+                        }
+                        for page in info.pages:
+                            page_info = {
+                                "page_id": page.page_id,
+                                "page_name": page.page_name,
+                                "type": page.type,
+                                "parent_id": page.parent_id,
+                                "is_bound": page.page_id in exist_page_ids,
+                                "page_icon": page.page_icon,
+                            }
+                            pages.append(page_info)
+            except Exception as e:
+                raise e
+            return {"notion_info": {**workspace_info, "pages": pages}}, 200


 class DataSourceNotionApi(Resource):

@@ -166,27 +201,25 @@ class DataSourceNotionApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, workspace_id, page_id, page_type):
+        credential_id = request.args.get("credential_id", default=None, type=str)
+        if not credential_id:
+            raise ValueError("Credential id is required.")
+        datasource_provider_service = DatasourceProviderService()
+        credential = datasource_provider_service.get_datasource_credentials(
+            tenant_id=current_user.current_tenant_id,
+            credential_id=credential_id,
+            provider="notion_datasource",
+            plugin_id="langgenius/notion_datasource",
+        )
+
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        with Session(db.engine) as session:
-            data_source_binding = session.execute(
-                select(DataSourceOauthBinding).where(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-                    )
-                )
-            ).scalar_one_or_none()
-            if not data_source_binding:
-                raise NotFound("Data source binding not found.")
-
         extractor = NotionExtractor(
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
             notion_page_type=page_type,
-            notion_access_token=data_source_binding.access_token,
+            notion_access_token=credential.get("integration_secret"),
             tenant_id=current_user.current_tenant_id,
         )

@@ -211,10 +244,12 @@ class DataSourceNotionApi(Resource):
         extract_settings = []
         for notion_info in notion_info_list:
             workspace_id = notion_info["workspace_id"]
+            credential_id = notion_info.get("credential_id")
             for page in notion_info["pages"]:
                 extract_setting = ExtractSetting(
                     datasource_type=DatasourceType.NOTION.value,
                     notion_info={
+                        "credential_id": credential_id,
                         "notion_workspace_id": workspace_id,
                         "notion_obj_id": page["page_id"],
                         "notion_page_type": page["type"],
@@ -20,7 +20,6 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType

@@ -33,6 +32,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService

@@ -337,6 +337,15 @@ class DatasetApi(Resource):
            location="json",
            help="Invalid external knowledge api id.",
        )

+        parser.add_argument(
+            "icon_info",
+            type=dict,
+            required=False,
+            nullable=True,
+            location="json",
+            help="Invalid icon info.",
+        )
        args = parser.parse_args()
        data = request.get_json()

@@ -387,7 +396,7 @@ class DatasetApi(Resource):
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor or current_user.is_dataset_operator:
+        if not (current_user.is_editor or current_user.is_dataset_operator):
            raise Forbidden()

        try:

@@ -503,10 +512,12 @@ class DatasetIndexingEstimateApi(Resource):
            notion_info_list = args["info_list"]["notion_info_list"]
            for notion_info in notion_info_list:
                workspace_id = notion_info["workspace_id"]
+                credential_id = notion_info.get("credential_id")
                for page in notion_info["pages"]:
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.NOTION.value,
                        notion_info={
+                            "credential_id": credential_id,
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page["page_id"],
                            "notion_page_type": page["type"],

@@ -730,6 +741,19 @@ class DatasetApiDeleteApi(Resource):
        return {"result": "success"}, 204


+@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
+class DatasetEnableApiApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, status):
+        dataset_id_str = str(dataset_id)
+
+        DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
+
+        return {"result": "success"}, 200
+
+
@console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource):
    @api.doc("get_dataset_api_base_info")
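For reviewers poking at the new dataset API toggle by hand, here is a minimal sketch of a console call; the base URL, the bearer token, and the dataset id are illustrative assumptions, not values defined in this diff.

import requests

# Assumed local console API and an already-issued console access token (both hypothetical).
BASE = "http://localhost:5001/console/api"
HEADERS = {"Authorization": "Bearer <console-access-token>"}
dataset_id = "<dataset-uuid>"  # hypothetical id

# Toggle the dataset's service API on ("enable") or off ("disable") via the new route.
resp = requests.post(f"{BASE}/datasets/{dataset_id}/api-keys/enable", headers=HEADERS)
print(resp.status_code, resp.json())  # expected: 200 {"result": "success"}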
@@ -1,3 +1,4 @@
+import json
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence

@@ -53,6 +54,7 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig

@@ -542,6 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.NOTION.value,
                        notion_info={
+                            "credential_id": data_source_info["credential_id"],
                            "notion_workspace_id": data_source_info["notion_workspace_id"],
                            "notion_obj_id": data_source_info["notion_page_id"],
                            "notion_page_type": data_source_info["type"],

@@ -716,7 +719,7 @@ class DocumentApi(DocumentResource):
            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
        elif metadata == "without":
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
            data_source_info = document.data_source_detail_dict
            response = {
                "id": document.id,

@@ -1108,3 +1111,64 @@ class WebsiteDocumentSyncApi(DocumentResource):
        DocumentService.sync_website_document(dataset_id, document)

        return {"result": "success"}, 200
+
+
+class DocumentPipelineExecutionLogApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        log = (
+            db.session.query(DocumentPipelineExecutionLog)
+            .filter_by(document_id=document_id)
+            .order_by(DocumentPipelineExecutionLog.created_at.desc())
+            .first()
+        )
+        if not log:
+            return {
+                "datasource_info": None,
+                "datasource_type": None,
+                "input_data": None,
+                "datasource_node_id": None,
+            }, 200
+        return {
+            "datasource_info": json.loads(log.datasource_info),
+            "datasource_type": log.datasource_type,
+            "input_data": log.input_data,
+            "datasource_node_id": log.datasource_node_id,
+        }, 200
+
+
+api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
+api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
+api.add_resource(DatasetInitApi, "/datasets/init")
+api.add_resource(
+    DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
+)
+api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
+api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
+api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
+api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
+api.add_resource(
+    DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
+)
+api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
+api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
+api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
+api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
+api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
+
+api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
+api.add_resource(
+    DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
+)
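A similarly hedged sketch of reading the new per-document pipeline execution log; the route path comes from the registration above, while the host, auth header, and ids are assumptions.

import requests

BASE = "http://localhost:5001/console/api"                     # assumed deployment URL
HEADERS = {"Authorization": "Bearer <console-access-token>"}   # assumed console auth
dataset_id, document_id = "<dataset-uuid>", "<document-uuid>"  # hypothetical ids

# Returns the latest DocumentPipelineExecutionLog entry, or null fields if none exists.
resp = requests.get(
    f"{BASE}/datasets/{dataset_id}/documents/{document_id}/pipeline-execution-log",
    headers=HEADERS,
)
print(resp.json())  # e.g. {"datasource_info": {...}, "datasource_type": "...", ...}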
@@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
    error_code = "child_chunk_delete_index_error"
    description = "Delete child chunk index failed: {message}"
    code = 500
+
+
+class PipelineNotFoundError(BaseHTTPException):
+    error_code = "pipeline_not_found"
+    description = "Pipeline not found."
+    code = 404
@@ -148,7 +148,7 @@ class ExternalApiTemplateApi(Resource):
        external_knowledge_api_id = str(external_knowledge_api_id)

        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor or current_user.is_dataset_operator:
+        if not (current_user.is_editor or current_user.is_dataset_operator):
            raise Forbidden()

        ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
@@ -0,0 +1,362 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService


class DatasourcePluginOAuthAuthorizationUrl(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_id: str):
        user = current_user
        tenant_id = user.current_tenant_id
        if not current_user.is_editor:
            raise Forbidden()

        credential_id = request.args.get("credential_id")
        datasource_provider_id = DatasourceProviderID(provider_id)
        provider_name = datasource_provider_id.provider_name
        plugin_id = datasource_provider_id.plugin_id
        oauth_config = DatasourceProviderService().get_oauth_client(
            tenant_id=tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        if not oauth_config:
            raise ValueError(f"No OAuth Client Config for {provider_id}")

        context_id = OAuthProxyService.create_proxy_context(
            user_id=current_user.id,
            tenant_id=tenant_id,
            plugin_id=plugin_id,
            provider=provider_name,
            credential_id=credential_id,
        )
        oauth_handler = OAuthHandler()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
        authorization_url_response = oauth_handler.get_authorization_url(
            tenant_id=tenant_id,
            user_id=user.id,
            plugin_id=plugin_id,
            provider=provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_config,
        )
        response = make_response(jsonable_encoder(authorization_url_response))
        response.set_cookie(
            "context_id",
            context_id,
            httponly=True,
            samesite="Lax",
            max_age=OAuthProxyService.__MAX_AGE__,
        )
        return response


class DatasourceOAuthCallback(Resource):
    @setup_required
    def get(self, provider_id: str):
        context_id = request.cookies.get("context_id") or request.args.get("context_id")
        if not context_id:
            raise Forbidden("context_id not found")

        context = OAuthProxyService.use_proxy_context(context_id)
        if context is None:
            raise Forbidden("Invalid context_id")

        user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
        datasource_provider_id = DatasourceProviderID(provider_id)
        plugin_id = datasource_provider_id.plugin_id
        datasource_provider_service = DatasourceProviderService()
        oauth_client_params = datasource_provider_service.get_oauth_client(
            tenant_id=tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        if not oauth_client_params:
            raise NotFound()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
        oauth_handler = OAuthHandler()
        oauth_response = oauth_handler.get_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            provider=datasource_provider_id.provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_client_params,
            request=request,
        )
        credential_id = context.get("credential_id")
        if credential_id:
            datasource_provider_service.reauthorize_datasource_oauth_provider(
                tenant_id=tenant_id,
                provider_id=datasource_provider_id,
                avatar_url=oauth_response.metadata.get("avatar_url") or None,
                name=oauth_response.metadata.get("name") or None,
                expire_at=oauth_response.expires_at,
                credentials=dict(oauth_response.credentials),
                credential_id=context.get("credential_id"),
            )
        else:
            datasource_provider_service.add_datasource_oauth_provider(
                tenant_id=tenant_id,
                provider_id=datasource_provider_id,
                avatar_url=oauth_response.metadata.get("avatar_url") or None,
                name=oauth_response.metadata.get("name") or None,
                expire_at=oauth_response.expires_at,
                credentials=dict(oauth_response.credentials),
            )
        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


class DatasourceAuth(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
        )
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()

        try:
            datasource_provider_service.add_datasource_api_key_provider(
                tenant_id=current_user.current_tenant_id,
                provider_id=datasource_provider_id,
                credentials=args["credentials"],
                name=args["name"],
            )
        except CredentialsValidateFailedError as ex:
            raise ValueError(str(ex))
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.list_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
        )
        return {"result": datasources}, 200


class DatasourceAuthDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        plugin_id = datasource_provider_id.plugin_id
        provider_name = datasource_provider_id.provider_name
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            auth_id=args["credential_id"],
            provider=provider_name,
            plugin_id=plugin_id,
        )
        return {"result": "success"}, 200


class DatasourceAuthUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        parser = reqparse.RequestParser()
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        if not current_user.is_editor:
            raise Forbidden()
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.update_datasource_credentials(
            tenant_id=current_user.current_tenant_id,
            auth_id=args["credential_id"],
            provider=datasource_provider_id.provider_name,
            plugin_id=datasource_provider_id.plugin_id,
            credentials=args.get("credentials", {}),
            name=args.get("name", None),
        )
        return {"result": "success"}, 201


class DatasourceAuthListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_all_datasource_credentials(
            tenant_id=current_user.current_tenant_id
        )
        return {"result": jsonable_encoder(datasources)}, 200


class DatasourceHardCodeAuthListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        datasource_provider_service = DatasourceProviderService()
        datasources = datasource_provider_service.get_hard_code_datasource_credentials(
            tenant_id=current_user.current_tenant_id
        )
        return {"result": jsonable_encoder(datasources)}, 200


class DatasourceAuthOauthCustomClient(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.setup_oauth_custom_client_params(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            client_params=args.get("client_params", {}),
            enabled=args.get("enable_oauth_custom_client", False),
        )
        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider_id: str):
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.remove_oauth_custom_client_params(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
        )
        return {"result": "success"}, 200


class DatasourceAuthDefaultApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.set_default_datasource_provider(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            credential_id=args["id"],
        )
        return {"result": "success"}, 200


class DatasourceUpdateProviderNameApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider_id: str):
        if not current_user.is_editor:
            raise Forbidden()
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
        parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()
        datasource_provider_id = DatasourceProviderID(provider_id)
        datasource_provider_service = DatasourceProviderService()
        datasource_provider_service.update_datasource_provider_name(
            tenant_id=current_user.current_tenant_id,
            datasource_provider_id=datasource_provider_id,
            name=args["name"],
            credential_id=args["credential_id"],
        )
        return {"result": "success"}, 200


api.add_resource(
    DatasourcePluginOAuthAuthorizationUrl,
    "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
    DatasourceOAuthCallback,
    "/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
    DatasourceAuth,
    "/auth/plugin/datasource/<path:provider_id>",
)

api.add_resource(
    DatasourceAuthUpdateApi,
    "/auth/plugin/datasource/<path:provider_id>/update",
)

api.add_resource(
    DatasourceAuthDeleteApi,
    "/auth/plugin/datasource/<path:provider_id>/delete",
)

api.add_resource(
    DatasourceAuthListApi,
    "/auth/plugin/datasource/list",
)

api.add_resource(
    DatasourceHardCodeAuthListApi,
    "/auth/plugin/datasource/default-list",
)

api.add_resource(
    DatasourceAuthOauthCustomClient,
    "/auth/plugin/datasource/<path:provider_id>/custom-client",
)

api.add_resource(
    DatasourceAuthDefaultApi,
    "/auth/plugin/datasource/<path:provider_id>/default",
)

api.add_resource(
    DatasourceUpdateProviderNameApi,
    "/auth/plugin/datasource/<path:provider_id>/update-name",
)
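To illustrate how a datasource credential could be registered against the new /auth/plugin/datasource/<provider_id> routes, a minimal sketch; the provider id, credential field names, host, and auth header are assumptions for illustration only.

import requests

BASE = "http://localhost:5001/console/api"                    # assumed deployment URL
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed console auth
provider_id = "<plugin_id>/<provider_name>"                   # hypothetical DatasourceProviderID string

# Add an API-key style credential; validation happens server-side in DatasourceProviderService.
resp = requests.post(
    f"{BASE}/auth/plugin/datasource/{provider_id}",
    headers=HEADERS,
    json={"name": "team-credential", "credentials": {"api_key": "<secret>"}},  # hypothetical fields
)
print(resp.status_code)  # 200 on success

# List the stored credentials for the same provider.
print(requests.get(f"{BASE}/auth/plugin/datasource/{provider_id}", headers=HEADERS).json())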
@@ -0,0 +1,57 @@
from flask_restx import (  # type: ignore
    Resource,  # type: ignore
    reqparse,
)
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService


class DataSourceContentPreviewApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def post(self, pipeline: Pipeline, node_id: str):
        """
        Run datasource content preview
        """
        if not isinstance(current_user, Account):
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("credential_id", type=str, required=False, location="json")
        args = parser.parse_args()

        inputs = args.get("inputs")
        if inputs is None:
            raise ValueError("missing inputs")
        datasource_type = args.get("datasource_type")
        if datasource_type is None:
            raise ValueError("missing datasource_type")

        rag_pipeline_service = RagPipelineService()
        preview_content = rag_pipeline_service.run_datasource_node_preview(
            pipeline=pipeline,
            node_id=node_id,
            user_inputs=inputs,
            account=current_user,
            datasource_type=datasource_type,
            is_published=True,
            credential_id=args.get("credential_id"),
        )
        return preview_content, 200


api.add_resource(
    DataSourceContentPreviewApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)
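A hedged sketch of invoking the new datasource content preview route on a published pipeline; the pipeline/node ids, the datasource_type value, the inputs payload, and the auth header are all placeholders rather than values guaranteed by this diff.

import requests

BASE = "http://localhost:5001/console/api"                        # assumed
HEADERS = {"Authorization": "Bearer <console-access-token>"}      # assumed
pipeline_id, node_id = "<pipeline-uuid>", "<datasource-node-id>"  # hypothetical

resp = requests.post(
    f"{BASE}/rag/pipelines/{pipeline_id}/workflows/published/datasource/nodes/{node_id}/preview",
    headers=HEADERS,
    json={
        "inputs": {},                         # datasource-specific inputs (assumed empty here)
        "datasource_type": "<datasource-type>",  # hypothetical type value
        "credential_id": "<credential-uuid>",
    },
)
print(resp.json())  # preview content from RagPipelineService.run_datasource_node_preview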
@@ -0,0 +1,164 @@
import logging

from flask import request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session

from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    enterprise_license_required,
    knowledge_pipeline_publish_enabled,
    setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class PipelineTemplateListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        type = request.args.get("type", default="built-in", type=str)
        language = request.args.get("language", default="en-US", type=str)
        # get pipeline templates
        pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
        return pipeline_templates, 200


class PipelineTemplateDetailApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, template_id: str):
        type = request.args.get("type", default="built-in", type=str)
        rag_pipeline_service = RagPipelineService()
        pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
        return pipeline_template, 200


class CustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def patch(self, template_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        pipeline_template_info = PipelineTemplateInfoEntity(**args)
        RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, template_id: str):
        RagPipelineService.delete_customized_pipeline_template(template_id)
        return 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, template_id: str):
        with Session(db.engine) as session:
            template = (
                session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
            )
            if not template:
                raise ValueError("Customized pipeline template not found.")

        return {"data": template.yaml_content}, 200


class PublishCustomizedPipelineTemplateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @knowledge_pipeline_publish_enabled
    def post(self, pipeline_id: str):
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument(
            "description",
            type=str,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "icon_info",
            type=dict,
            location="json",
            nullable=True,
        )
        args = parser.parse_args()
        rag_pipeline_service = RagPipelineService()
        rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
        return {"result": "success"}


api.add_resource(
    PipelineTemplateListApi,
    "/rag/pipeline/templates",
)
api.add_resource(
    PipelineTemplateDetailApi,
    "/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
    CustomizedPipelineTemplateApi,
    "/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
    PublishCustomizedPipelineTemplateApi,
    "/rag/pipelines/<string:pipeline_id>/customized/publish",
)
@@ -0,0 +1,114 @@
from flask_login import current_user  # type: ignore  # type: ignore
from flask_restx import Resource, marshal, reqparse  # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_rate_limit_check,
    setup_required,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 to 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class CreateRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        parser = reqparse.RequestParser()

        parser.add_argument(
            "yaml_content",
            type=str,
            nullable=False,
            required=True,
            help="yaml_content is required.",
        )

        args = parser.parse_args()

        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
            name="",
            description="",
            icon_info=IconInfo(
                icon="📙",
                icon_background="#FFF4ED",
                icon_type="emoji",
            ),
            permission=DatasetPermissionEnum.ONLY_ME,
            partial_member_list=None,
            yaml_content=args["yaml_content"],
        )
        try:
            with Session(db.engine) as session:
                rag_pipeline_dsl_service = RagPipelineDslService(session)
                import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
                    tenant_id=current_user.current_tenant_id,
                    rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
                )
                if rag_pipeline_dataset_create_entity.permission == "partial_members":
                    DatasetPermissionService.update_partial_member_list(
                        current_user.current_tenant_id,
                        import_info["dataset_id"],
                        rag_pipeline_dataset_create_entity.partial_member_list,
                    )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return import_info, 201


class CreateEmptyRagPipelineDatasetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self):
        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()
        dataset = DatasetService.create_empty_rag_pipeline_dataset(
            tenant_id=current_user.current_tenant_id,
            rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
                name="",
                description="",
                icon_info=IconInfo(
                    icon="📙",
                    icon_background="#FFF4ED",
                    icon_type="emoji",
                ),
                permission=DatasetPermissionEnum.ONLY_ME,
                partial_member_list=None,
            ),
        )
        return marshal(dataset, dataset_detail_fields), 201


api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
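One more hedged usage sketch: creating a knowledge-pipeline dataset from a DSL export via the new /rag/pipeline/dataset route; the YAML file, host, and auth header shown here are stand-ins, not a schema guaranteed by this diff.

import requests

BASE = "http://localhost:5001/console/api"                    # assumed
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed

# A pipeline DSL exported elsewhere (hypothetical file name).
yaml_content = open("knowledge_pipeline.yml").read()

resp = requests.post(
    f"{BASE}/rag/pipeline/dataset",
    headers=HEADERS,
    json={"yaml_content": yaml_content},
)
print(resp.status_code, resp.json())  # 201 with import info (e.g. dataset_id) on success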
@@ -0,0 +1,389 @@
import logging
from typing import Any, NoReturn

from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.app.error import (
    DraftWorkflowNotExist,
)
from controllers.console.app.workflow_draft_variable import (
    _WORKFLOW_DRAFT_VARIABLE_FIELDS,
    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

logger = logging.getLogger(__name__)


def _convert_values_to_json_serializable_object(value: Segment) -> Any:
    if isinstance(value, FileSegment):
        return value.value.model_dump()
    elif isinstance(value, ArrayFileSegment):
        return [i.model_dump() for i in value.value]
    elif isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(i) for i in value.value]
    else:
        return value.value


def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
    value = variable.get_value()
    # create a copy of the value to avoid affecting the model cache.
    value = value.model_copy(deep=True)
    # Refresh the url signature before returning it to client.
    if isinstance(value, FileSegment):
        file = value.value
        file.remote_url = file.generate_url()
    elif isinstance(value, ArrayFileSegment):
        files = value.value
        for file in files:
            file.remote_url = file.generate_url()
    return _convert_values_to_json_serializable_object(value)


def _create_pagination_parser():
    parser = reqparse.RequestParser()
    parser.add_argument(
        "page",
        type=inputs.int_range(1, 100_000),
        required=False,
        default=1,
        location="args",
        help="the page of data requested",
    )
    parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
    return parser


def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
    return var_list.variables


_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
    "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
    "total": fields.Raw(),
}

_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
    "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}


def _api_prerequisite(f):
    """Common prerequisites for all draft workflow variable APIs.

    It ensures the following conditions are satisfied:

    - Dify has been property setup.
    - The request user has logged in and initialized.
    - The requested app is a workflow or a chat flow.
    - The request user has the edit permission for the app.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @get_rag_pipeline
    def wrapper(*args, **kwargs):
        if not isinstance(current_user, Account) or not current_user.is_editor:
            raise Forbidden()
        return f(*args, **kwargs)

    return wrapper


class RagPipelineVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
    def get(self, pipeline: Pipeline):
        """
        Get draft workflow
        """
        parser = _create_pagination_parser()
        args = parser.parse_args()

        # fetch draft workflow by app_model
        rag_pipeline_service = RagPipelineService()
        workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
        if not workflow_exist:
            raise DraftWorkflowNotExist()

        # fetch draft workflow by app_model
        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(
                session=session,
            )
            workflow_vars = draft_var_srv.list_variables_without_values(
                app_id=pipeline.id,
                page=args.page,
                limit=args.limit,
            )

        return workflow_vars

    @_api_prerequisite
    def delete(self, pipeline: Pipeline):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        draft_var_srv.delete_workflow_variables(pipeline.id)
        db.session.commit()
        return Response("", 204)


def validate_node_id(node_id: str) -> NoReturn | None:
    if node_id in [
        CONVERSATION_VARIABLE_NODE_ID,
        SYSTEM_VARIABLE_NODE_ID,
    ]:
        # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
        # with specific `node_id` in database, we still want to make the API separated. By disallowing
        # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
        # we mitigate the risk that user of the API depending on the implementation detail of the API.
        #
        # ref: [Hyrum's Law](https://www.hyrumslaw.com/)

        raise InvalidArgumentError(
            f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
        )
    return None


class RagPipelineNodeVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, pipeline: Pipeline, node_id: str):
        validate_node_id(node_id)
        with Session(bind=db.engine, expire_on_commit=False) as session:
            draft_var_srv = WorkflowDraftVariableService(
                session=session,
            )
            node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)

        return node_vars

    @_api_prerequisite
    def delete(self, pipeline: Pipeline, node_id: str):
        validate_node_id(node_id)
        srv = WorkflowDraftVariableService(db.session())
        srv.delete_node_variables(pipeline.id, node_id)
        db.session.commit()
        return Response("", 204)


class RagPipelineVariableApi(Resource):
    _PATCH_NAME_FIELD = "name"
    _PATCH_VALUE_FIELD = "value"

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def get(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        return variable

    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
    def patch(self, pipeline: Pipeline, variable_id: str):
        # Request payload for file types:
        #
        # Local File:
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "local_file",
        #         "url": "",
        #         "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
        #     }
        #
        # Remote File:
        #
        #
        #     {
        #         "type": "image",
        #         "transfer_method": "remote_url",
        #         "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
        #         "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
        #     }

        parser = reqparse.RequestParser()
        parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
        # Parse 'value' field as-is to maintain its original data structure
        parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")

        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        args = parser.parse_args(strict=True)

        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")

        new_name = args.get(self._PATCH_NAME_FIELD, None)
        raw_value = args.get(self._PATCH_VALUE_FIELD, None)
        if new_name is None and raw_value is None:
            return variable

        new_value = None
        if raw_value is not None:
            if variable.value_type == SegmentType.FILE:
                if not isinstance(raw_value, dict):
                    raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
                raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
            elif variable.value_type == SegmentType.ARRAY_FILE:
                if not isinstance(raw_value, list):
                    raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
                if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
                    raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
                raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
            new_value = build_segment_with_type(variable.value_type, raw_value)
        draft_var_srv.update_variable(variable, name=new_name, value=new_value)
        db.session.commit()
        return variable

    @_api_prerequisite
    def delete(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        draft_var_srv.delete_variable(variable)
        db.session.commit()
        return Response("", 204)


class RagPipelineVariableResetApi(Resource):
    @_api_prerequisite
    def put(self, pipeline: Pipeline, variable_id: str):
        draft_var_srv = WorkflowDraftVariableService(
            session=db.session(),
        )

        rag_pipeline_service = RagPipelineService()
        draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
        if draft_workflow is None:
            raise NotFoundError(
                f"Draft workflow not found, pipeline_id={pipeline.id}",
            )
        variable = draft_var_srv.get_variable(variable_id=variable_id)
        if variable is None:
            raise NotFoundError(description=f"variable not found, id={variable_id}")
        if variable.app_id != pipeline.id:
            raise NotFoundError(description=f"variable not found, id={variable_id}")

        resetted = draft_var_srv.reset_variable(draft_workflow, variable)
        db.session.commit()
        if resetted is None:
            return Response("", 204)
        else:
            return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)


def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
    with Session(bind=db.engine, expire_on_commit=False) as session:
        draft_var_srv = WorkflowDraftVariableService(
            session=session,
        )
        if node_id == CONVERSATION_VARIABLE_NODE_ID:
            draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
        elif node_id == SYSTEM_VARIABLE_NODE_ID:
            draft_vars = draft_var_srv.list_system_variables(pipeline.id)
        else:
            draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
    return draft_vars


class RagPipelineSystemVariableCollectionApi(Resource):
    @_api_prerequisite
    @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
    def get(self, pipeline: Pipeline):
        return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)


class RagPipelineEnvironmentVariableCollectionApi(Resource):
    @_api_prerequisite
    def get(self, pipeline: Pipeline):
        """
        Get draft workflow
        """
        # fetch draft workflow by app_model
        rag_pipeline_service = RagPipelineService()
        workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
        if workflow is None:
            raise DraftWorkflowNotExist()

        env_vars = workflow.environment_variables
        env_vars_list = []
        for v in env_vars:
            env_vars_list.append(
                {
                    "id": v.id,
                    "type": "env",
                    "name": v.name,
                    "description": v.description,
                    "selector": v.selector,
                    "value_type": v.value_type.value,
                    "value": v.value,
                    # Do not track edited for env vars.
                    "edited": False,
                    "visible": True,
                    "editable": True,
                }
            )

        return {"items": env_vars_list}


api.add_resource(
    RagPipelineVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
    RagPipelineNodeVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
    RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
    RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
    RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
    RagPipelineEnvironmentVariableCollectionApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)
@@ -0,0 +1,147 @@
from typing import cast

from flask_login import current_user  # type: ignore
from flask_restx import Resource, marshal_with, reqparse  # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService


class RagPipelineImportApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("mode", type=str, required=True, location="json")
        parser.add_argument("yaml_content", type=str, location="json")
        parser.add_argument("yaml_url", type=str, location="json")
        parser.add_argument("name", type=str, location="json")
        parser.add_argument("description", type=str, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
        parser.add_argument("pipeline_id", type=str, location="json")
        args = parser.parse_args()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Import app
            account = cast(Account, current_user)
            result = import_service.import_rag_pipeline(
                account=account,
                import_mode=args["mode"],
                yaml_content=args.get("yaml_content"),
                yaml_url=args.get("yaml_url"),
                pipeline_id=args.get("pipeline_id"),
                dataset_name=args.get("name"),
            )
            session.commit()

        # Return appropriate status code based on result
        status = result.status
        if status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        elif status == ImportStatus.PENDING.value:
            return result.model_dump(mode="json"), 202
        return result.model_dump(mode="json"), 200


class RagPipelineImportConfirmApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(pipeline_import_fields)
    def post(self, import_id):
        # Check user role first
        if not current_user.is_editor:
            raise Forbidden()

        # Create service with session
        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            # Confirm import
            account = cast(Account, current_user)
            result = import_service.confirm_import(import_id=import_id, account=account)
            session.commit()

        # Return appropriate status code based on result
        if result.status == ImportStatus.FAILED.value:
            return result.model_dump(mode="json"), 400
        return result.model_dump(mode="json"), 200


class RagPipelineImportCheckDependenciesApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    @marshal_with(pipeline_import_check_dependencies_fields)
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        with Session(db.engine) as session:
            import_service = RagPipelineDslService(session)
            result = import_service.check_dependencies(pipeline=pipeline)

        return result.model_dump(mode="json"), 200


class RagPipelineExportApi(Resource):
    @setup_required
    @login_required
    @get_rag_pipeline
    @account_initialization_required
    def get(self, pipeline: Pipeline):
        if not current_user.is_editor:
            raise Forbidden()

        # Add include_secret params
        parser = reqparse.RequestParser()
        parser.add_argument("include_secret", type=bool, default=False, location="args")
        args = parser.parse_args()

        with Session(db.engine) as session:
            export_service = RagPipelineDslService(session)
            result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])

        return {"data": result}, 200


# Import Rag Pipeline
api.add_resource(
    RagPipelineImportApi,
    "/rag/pipelines/imports",
)
api.add_resource(
    RagPipelineImportConfirmApi,
    "/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
    RagPipelineImportCheckDependenciesApi,
    "/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
    RagPipelineExportApi,
    "/rag/pipelines/<string:pipeline_id>/exports",
)

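# Illustrative client sketch for the import endpoints registered above, assuming a
# locally running console API and a valid console session in `headers`. The base URL,
# the auth details, the "yaml-content" mode value, and the `id` field in the response
# are assumptions; only the /rag/pipelines/imports paths come from the code above.
import requests

base_url = "http://localhost:5001/console/api"
headers = {"Content-Type": "application/json"}  # plus whatever auth the console session uses

resp = requests.post(
    f"{base_url}/rag/pipelines/imports",
    json={"mode": "yaml-content", "yaml_content": "<pipeline DSL YAML>"},
    headers=headers,
)
result = resp.json()
if resp.status_code == 202:
    # A pending import (e.g. unresolved dependencies) is finalized with a confirm call.
    requests.post(f"{base_url}/rag/pipelines/imports/{result['id']}/confirm", headers=headers)
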
File diff suppressed because it is too large
@@ -0,0 +1,46 @@
from collections.abc import Callable
from functools import wraps

from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline


def get_rag_pipeline(
    view: Callable | None = None,
):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            if not kwargs.get("pipeline_id"):
                raise ValueError("missing pipeline_id in path parameters")

            if not isinstance(current_user, Account):
                raise ValueError("current_user is not an account")

            pipeline_id = kwargs.get("pipeline_id")
            pipeline_id = str(pipeline_id)

            del kwargs["pipeline_id"]

            pipeline = (
                db.session.query(Pipeline)
                .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
                .first()
            )

            if not pipeline:
                raise PipelineNotFoundError()

            kwargs["pipeline"] = pipeline

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)

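# A minimal sketch of how the get_rag_pipeline decorator above is intended to be used;
# the resource class and route here are hypothetical. The decorator expects a logged-in
# Account (so it is normally stacked with the console auth decorators), pops `pipeline_id`
# from the URL kwargs, and injects the loaded, tenant-scoped `pipeline` model instead.
from flask_restx import Resource

from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline


class ExamplePipelineApi(Resource):
    @get_rag_pipeline
    def get(self, pipeline: Pipeline):
        # `pipeline` has already been looked up for the current tenant
        return {"id": pipeline.id}


api.add_resource(ExamplePipelineApi, "/rag/pipelines/<uuid:pipeline_id>/example")
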
@@ -20,6 +20,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
 from libs.login import current_user
 from models.model import AppMode, InstalledApp
@@ -82,6 +83,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
             raise NotWorkflowAppError()
         assert current_user is not None
 
-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)
 
         return {"result": "success"}
@@ -20,6 +20,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_resource_check,
     setup_required,
 )
+from extensions.ext_database import db
 from fields.file_fields import file_fields, upload_config_fields
 from libs.login import login_required
 from models import Account
@@ -68,10 +69,11 @@ class FileApi(Resource):
         if source not in ("datasets", None):
             source = None
 
+        if not isinstance(current_user, Account):
+            raise ValueError("Invalid user account")
+
         try:
-            if not isinstance(current_user, Account):
-                raise ValueError("Invalid user account")
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -92,7 +94,7 @@ class FilePreviewApi(Resource):
     @account_initialization_required
     def get(self, file_id):
         file_id = str(file_id)
-        text = FileService.get_file_preview(file_id)
+        text = FileService(db.engine).get_file_preview(file_id)
         return {"content": text}
@@ -14,6 +14,7 @@ from controllers.common.errors import (
 )
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
+from extensions.ext_database import db
 from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
 from models.account import Account
 from services.file_service import FileService
@@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
 
         try:
             user = cast(Account, current_user)
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file_info.filename,
                 content=content,
                 mimetype=file_info.mimetype,
@@ -0,0 +1,35 @@
import logging

from flask_restx import Resource

from controllers.console import api
from controllers.console.wraps import (
    account_initialization_required,
    setup_required,
)
from core.schemas.schema_manager import SchemaManager
from libs.login import login_required

logger = logging.getLogger(__name__)


class SpecSchemaDefinitionsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """
        Get system JSON Schema definitions specification
        Used for frontend component type mapping
        """
        try:
            schema_manager = SchemaManager()
            schema_definitions = schema_manager.get_all_schema_definitions()
            return schema_definitions, 200
        except Exception:
            logger.exception("Failed to get schema definitions from local registry")
            # Return empty array as fallback
            return [], 200


api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

@@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
 from core.mcp.error import MCPAuthError, MCPError
 from core.mcp.mcp_client import MCPClient
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ToolProviderID
 from core.plugin.impl.oauth import OAuthHandler
 from core.tools.entities.tool_entities import CredentialType
 from libs.helper import StrLen, alphanumeric, uuid_value
 from libs.login import login_required
+from models.provider_ids import ToolProviderID
 from services.plugin.oauth_service import OAuthProxyService
 from services.tools.api_tools_manage_service import ApiToolManageService
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService
@@ -227,7 +227,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise UnsupportedFileTypeError()
 
         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -279,3 +279,14 @@ def is_allow_transfer_owner(view: Callable[P, R]):
         abort(403)
 
     return decorated
+
+
+def knowledge_pipeline_publish_enabled(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        features = FeatureService.get_features(current_user.current_tenant_id)
+        if features.knowledge_pipeline.publish_enabled:
+            return view(*args, **kwargs)
+        abort(403)
+
+    return decorated
@@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
 import services
 from controllers.common.errors import UnsupportedFileTypeError
 from controllers.files import files_ns
+from extensions.ext_database import db
 from services.account_service import TenantService
 from services.file_service import FileService
@@ -28,7 +29,7 @@ class ImagePreviewApi(Resource):
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_image_preview(
+            generator, mimetype = FileService(db.engine).get_image_preview(
                 file_id=file_id,
                 timestamp=timestamp,
                 nonce=nonce,
@@ -57,7 +58,7 @@ class FilePreviewApi(Resource):
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, upload_file = FileService.get_file_generator_by_file_id(
+            generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
                 file_id=file_id,
                 timestamp=args["timestamp"],
                 nonce=args["nonce"],
@@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource):
             raise NotFound("webapp logo is not found")
 
         try:
-            generator, mimetype = FileService.get_public_image_preview(
+            generator, mimetype = FileService(db.engine).get_public_image_preview(
                 webapp_logo_file_id,
             )
         except services.errors.file.UnsupportedFileTypeError:
@@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
 from controllers.files import files_ns
 from core.tools.signature import verify_tool_file_signature
 from core.tools.tool_file_manager import ToolFileManager
-from models import db as global_db
+from extensions.ext_database import db as global_db
 
 
 @files_ns.route("/tools/<uuid:file_id>.<string:extension>")
@@ -420,7 +420,12 @@ class PluginUploadFileRequestApi(Resource):
     )
     def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
         # generate signed url
-        url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
+        url = get_signed_file_url_for_plugin(
+            filename=payload.filename,
+            mimetype=payload.mimetype,
+            tenant_id=tenant_model.id,
+            user_id=user_model.id,
+        )
         return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
@@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
                 user_model = (
                     session.query(EndUser)
                     .where(
-                        EndUser.session_id == user_id,
+                        EndUser.id == user_id,
                         EndUser.tenant_id == tenant_id,
                     )
                     .first()
                 )
+                if not user_model:
+                    user_model = (
+                        session.query(EndUser)
+                        .where(
+                            EndUser.session_id == user_id,
+                            EndUser.tenant_id == tenant_id,
+                        )
+                        .first()
+                    )
                 if not user_model:
                     user_model = EndUser(
                         tenant_id=tenant_id,
@@ -12,8 +12,9 @@ from controllers.common.errors import (
 )
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+from extensions.ext_database import db
 from fields.file_fields import build_file_model
-from models.model import App, EndUser
+from models import App, EndUser
 from services.file_service import FileService
@@ -52,7 +53,7 @@ class FileApi(Resource):
             raise FilenameNotExistsError
 
         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -26,7 +26,8 @@ from core.errors.error import (
 )
 from core.helper.trace_id_helper import get_external_trace_id
 from core.model_runtime.errors.invoke import InvokeError
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
 from libs import helper
@@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()
 
-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)
 
         return {"result": "success"}
@@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
     validate_dataset_token,
 )
 from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import build_dataset_tag_fields
 from libs.login import current_user
 from models.account import Account
 from models.dataset import Dataset, DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
 from services.tag_service import TagService
@@ -124,7 +124,12 @@ class DocumentAddByTextApi(DatasetApiResource):
                 args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
             )
 
-        upload_file = FileService.upload_text(text=str(text), text_name=str(name))
+        if not current_user:
+            raise ValueError("current_user is required")
+
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -134,6 +139,9 @@ class DocumentAddByTextApi(DatasetApiResource):
         # validate args
         DocumentService.document_create_args_validate(knowledge_config)
 
+        if not current_user:
+            raise ValueError("current_user is required")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
@@ -199,7 +207,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         name = args.get("name")
         if text is None or name is None:
             raise ValueError("Both text and name must be strings.")
-        upload_file = FileService.upload_text(text=str(text), text_name=str(name))
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_text(
+            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
+        )
         data_source = {
             "type": "upload_file",
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@@ -301,8 +313,9 @@ class DocumentAddByFileApi(DatasetApiResource):
 
         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
-        upload_file = FileService.upload_file(
+        if not current_user:
+            raise ValueError("current_user is required")
+        upload_file = FileService(db.engine).upload_file(
             filename=file.filename,
             content=file.read(),
             mimetype=file.mimetype,
@@ -390,10 +403,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError
 
+        if not current_user:
+            raise ValueError("current_user is required")
+
+        if not isinstance(current_user, EndUser):
+            raise ValueError("Invalid user account")
+
         try:
-            if not isinstance(current_user, EndUser):
-                raise ValueError("Invalid user account")
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -571,7 +588,7 @@ class DocumentApi(DatasetApiResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -604,7 +621,7 @@ class DocumentApi(DatasetApiResource):
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
     error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
+
+
+class PipelineRunError(BaseHTTPException):
+    error_code = "pipeline_run_error"
+    description = "An error occurred while running the pipeline."
+    code = 500
@@ -133,7 +133,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
         return 204
 
 
-@service_api_ns.route("/datasets/metadata/built-in")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
 class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
     @service_api_ns.doc("get_built_in_fields")
     @service_api_ns.doc(description="Get all built-in metadata fields")
@@ -143,7 +143,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
             401: "Unauthorized - invalid API token",
         }
     )
-    def get(self, tenant_id):
+    def get(self, tenant_id, dataset_id):
         """Get all built-in metadata fields."""
         built_in_fields = MetadataService.get_built_in_fields()
         return {"fields": built_in_fields}, 200
@@ -0,0 +1,242 @@
import string
import uuid
from collections.abc import Generator
from typing import Any

from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from werkzeug.exceptions import Forbidden

import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
    """Resource for datasource plugins."""

    @service_api_ns.doc(shortcut="list_rag_pipeline_datasource_plugins")
    @service_api_ns.doc(description="List all datasource plugins for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        params={
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)"
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource plugins retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id: str, dataset_id: str):
        """Resource for getting datasource plugins."""
        # Get query parameter to determine published or draft
        is_published: bool = request.args.get("is_published", default=True, type=bool)

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
            tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
        )
        return datasource_plugins, 200


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
class DatasourceNodeRunApi(DatasetApiResource):
    """Resource for datasource node run."""

    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "credential_id": "Credential ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Datasource node run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str, node_id: str):
        """Resource for getting datasource plugins."""
        # Get query parameter to determine published or draft
        parser: RequestParser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("credential_id", type=str, required=False, location="json")
        parser.add_argument("is_published", type=bool, required=True, location="json")
        args: ParseResult = parser.parse_args()

        datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
        assert isinstance(current_user, Account)
        rag_pipeline_service: RagPipelineService = RagPipelineService()
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
        return helper.compact_generate_response(
            PipelineGenerator.convert_to_event_stream(
                rag_pipeline_service.run_datasource_workflow_node(
                    pipeline=pipeline,
                    node_id=node_id,
                    user_inputs=datasource_node_run_api_entity.inputs,
                    account=current_user,
                    datasource_type=datasource_node_run_api_entity.datasource_type,
                    is_published=datasource_node_run_api_entity.is_published,
                    credential_id=datasource_node_run_api_entity.credential_id,
                )
            )
        )


@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
class PipelineRunApi(DatasetApiResource):
    """Resource for datasource node run."""

    @service_api_ns.doc(shortcut="pipeline_datasource_node_run")
    @service_api_ns.doc(description="Run a datasource node for a rag pipeline")
    @service_api_ns.doc(
        path={
            "dataset_id": "Dataset ID",
        }
    )
    @service_api_ns.doc(
        body={
            "inputs": "User input variables",
            "datasource_type": "Datasource type, e.g. online_document",
            "datasource_info_list": "Datasource info list",
            "start_node_id": "Start node ID",
            "is_published": "Whether to get published or draft datasource plugins "
            "(true for published, false for draft, default: true)",
            "streaming": "Whether to stream the response(streaming or blocking), default: streaming",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Pipeline run successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def post(self, tenant_id: str, dataset_id: str):
        """Resource for running a rag pipeline."""
        parser: RequestParser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("datasource_type", type=str, required=True, location="json")
        parser.add_argument("datasource_info_list", type=list, required=True, location="json")
        parser.add_argument("start_node_id", type=str, required=True, location="json")
        parser.add_argument("is_published", type=bool, required=True, default=True, location="json")
        parser.add_argument(
            "response_mode",
            type=str,
            required=True,
            choices=["streaming", "blocking"],
            default="blocking",
            location="json",
        )
        args: ParseResult = parser.parse_args()

        if not isinstance(current_user, Account):
            raise Forbidden()

        rag_pipeline_service: RagPipelineService = RagPipelineService()
        pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
        try:
            response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
                pipeline=pipeline,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
                streaming=args.get("response_mode") == "streaming",
            )

            return helper.compact_generate_response(response)
        except Exception as ex:
            raise PipelineRunError(description=str(ex))


@service_api_ns.route("/datasets/pipeline/file-upload")
class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
    """Resource for uploading a file to a knowledgebase pipeline."""

    @service_api_ns.doc(shortcut="knowledgebase_pipeline_file_upload")
    @service_api_ns.doc(description="Upload a file to a knowledgebase pipeline")
    @service_api_ns.doc(
        responses={
            201: "File uploaded successfully",
            400: "Bad request - no file or invalid file",
            401: "Unauthorized - invalid API token",
            413: "File too large",
            415: "Unsupported file type",
        }
    )
    def post(self, tenant_id: str):
        """Upload a file for use in conversations.

        Accepts a single file upload via multipart/form-data.
        """
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.mimetype:
            raise UnsupportedFileTypeError()

        if not file.filename:
            raise FilenameNotExistsError

        if not current_user:
            raise ValueError("Invalid user account")

        try:
            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "mime_type": upload_file.mime_type,
            "created_by": upload_file.created_by,
            "created_at": upload_file.created_at,
        }, 201

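# A hedged sketch of invoking the pipeline run endpoint defined above from outside,
# assuming the service API is mounted under /v1 and authenticated with a dataset API
# token; the base URL, token, and all field values below are placeholders, only the
# path and the parsed request fields come from the code above.
import requests

resp = requests.post(
    "http://localhost:5001/v1/datasets/<dataset-uuid>/pipeline/run",
    headers={"Authorization": "Bearer <dataset-api-token>"},
    json={
        "inputs": {},
        "datasource_type": "online_document",
        "datasource_info_list": [],
        "start_node_id": "<datasource-node-id>",
        "is_published": True,
        "response_mode": "blocking",
    },
)
print(resp.status_code, resp.json())
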
@@ -193,6 +193,47 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
     def decorator(view: Callable[Concatenate[T, P], R]):
         @wraps(view)
         def decorated(*args: P.args, **kwargs: P.kwargs):
+            # get url path dataset_id from positional args or kwargs
+            # Flask passes URL path parameters as positional arguments
+            dataset_id = None
+
+            # First try to get from kwargs (explicit parameter)
+            dataset_id = kwargs.get("dataset_id")
+
+            # If not in kwargs, try to extract from positional args
+            if not dataset_id and args:
+                # For class methods: args[0] is self, args[1] is dataset_id (if exists)
+                # Check if first arg is likely a class instance (has __dict__ or __class__)
+                if len(args) > 1 and hasattr(args[0], "__dict__"):
+                    # This is a class method, dataset_id should be in args[1]
+                    potential_id = args[1]
+                    # Validate it's a string-like UUID, not another object
+                    try:
+                        # Try to convert to string and check if it's a valid UUID format
+                        str_id = str(potential_id)
+                        # Basic check: UUIDs are 36 chars with hyphens
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+                elif len(args) > 0:
+                    # Not a class method, check if args[0] looks like a UUID
+                    potential_id = args[0]
+                    try:
+                        str_id = str(potential_id)
+                        if len(str_id) == 36 and str_id.count("-") == 4:
+                            dataset_id = str_id
+                    except:
+                        pass
+
+            # Validate dataset if dataset_id is provided
+            if dataset_id:
+                dataset_id = str(dataset_id)
+                dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+                if not dataset:
+                    raise NotFound("Dataset not found.")
+                if not dataset.enable_api:
+                    raise Forbidden("Dataset api access is not enabled.")
             api_token = validate_and_get_api_token("dataset")
             tenant_account_join = (
                 db.session.query(Tenant, TenantAccountJoin)
@@ -11,6 +11,7 @@ from controllers.common.errors import (
 )
 from controllers.web import web_ns
 from controllers.web.wraps import WebApiResource
+from extensions.ext_database import db
 from fields.file_fields import build_file_model
 from services.file_service import FileService
@@ -68,7 +69,7 @@ class FileApi(WebApiResource):
             source = None
 
         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
                 content=file.read(),
                 mimetype=file.mimetype,
@@ -14,6 +14,7 @@ from controllers.web import web_ns
 from controllers.web.wraps import WebApiResource
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
+from extensions.ext_database import db
 from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
 from services.file_service import FileService
@@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource):
         content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
 
         try:
-            upload_file = FileService.upload_file(
+            upload_file = FileService(db.engine).upload_file(
                 filename=file_info.filename,
                 content=content,
                 mimetype=file_info.mimetype,
@@ -21,6 +21,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
@@ -112,6 +113,11 @@ class WorkflowTaskStopApi(WebApiResource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()
 
-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)
 
         return {"result": "success"}
@@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
             tenant_id=tenant_id,
             dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
-            return_resource=app_config.additional_features.show_retrieve_source,
+            return_resource=(
+                app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+            ),
             invoke_from=application_generate_entity.invoke_from,
             hit_callback=hit_callback,
             user_id=user_id,
@@ -4,8 +4,8 @@ from typing import Any
 from core.app.app_config.entities import ModelConfigEntity
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
+from models.provider_ids import ModelProviderID
 
 
 class ModelConfigManager:
@@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
     hide: bool = False
     max_length: int | None = None
     options: Sequence[str] = Field(default_factory=list)
-    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
-    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
-    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
 
     @field_validator("description", mode="before")
     @classmethod
@@ -129,6 +129,16 @@ class VariableEntity(BaseModel):
         return v or []
 
 
+class RagPipelineVariableEntity(VariableEntity):
+    """
+    Rag Pipeline Variable Entity.
+    """
+
+    tooltips: str | None = None
+    placeholder: str | None = None
+    belong_to_node_id: str
+
+
 class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
@@ -288,7 +298,7 @@ class AppConfig(BaseModel):
     tenant_id: str
     app_id: str
     app_mode: AppMode
-    additional_features: AppAdditionalFeatures
+    additional_features: AppAdditionalFeatures | None = None
     variables: list[VariableEntity] = []
     sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
@@ -1,4 +1,6 @@
-from core.app.app_config.entities import VariableEntity
+import re
+
+from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
 from models.workflow import Workflow
 
 
@@ -20,3 +22,48 @@ class WorkflowVariablesConfigManager:
             variables.append(VariableEntity.model_validate(variable))
 
         return variables
+
+    @classmethod
+    def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
+        """
+        Convert workflow start variables to variables
+
+        :param workflow: workflow instance
+        """
+        variables = []
+
+        # get second step node
+        rag_pipeline_variables = workflow.rag_pipeline_variables
+        if not rag_pipeline_variables:
+            return []
+        variables_map = {item["variable"]: item for item in rag_pipeline_variables}
+
+        # get datasource node data
+        datasource_node_data = None
+        datasource_nodes = workflow.graph_dict.get("nodes", [])
+        for datasource_node in datasource_nodes:
+            if datasource_node.get("id") == start_node_id:
+                datasource_node_data = datasource_node.get("data", {})
+                break
+        if datasource_node_data:
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+
+            for _, value in datasource_parameters.items():
+                if value.get("value") and isinstance(value.get("value"), str):
+                    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
+                    match = re.match(pattern, value["value"])
+                    if match:
+                        full_path = match.group(1)
+                        last_part = full_path.split(".")[-1]
+                        variables_map.pop(last_part, None)
+                if value.get("value") and isinstance(value.get("value"), list):
+                    last_part = value.get("value")[-1]
+                    variables_map.pop(last_part, None)
+
+        all_second_step_variables = list(variables_map.values())
+
+        for item in all_second_step_variables:
+            if item.get("belong_to_node_id") == start_node_id or item.get("belong_to_node_id") == "shared":
+                variables.append(RagPipelineVariableEntity.model_validate(item))
+
+        return variables
@@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
 
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
-            app_config.additional_features.show_retrieve_source = True
+            app_config.additional_features.show_retrieve_source = True  # type: ignore
 
         workflow_run_id = str(uuid.uuid4())
         # init application generate entity
@@ -467,7 +467,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             stream=stream,
-            draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
+            draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
         )
 
         return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

@@ -1,11 +1,11 @@
 import logging
+import time
 from collections.abc import Mapping
 from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
-from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models import Workflow
 from models.enums import UserFrom
 from models.model import App, Conversation, Message, MessageAnnotation
-from models.workflow import ConversationVariable, WorkflowType
+from models.workflow import ConversationVariable
 
 logger = logging.getLogger(__name__)
 
@@ -78,23 +79,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         if not app_record:
             raise ValueError("App not found")
 
-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         if self.application_generate_entity.single_iteration_run:
             # if only single iteration run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
                 user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
+                graph_runtime_state=graph_runtime_state,
             )
         elif self.application_generate_entity.single_loop_run:
             # if only single loop run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_loop_run.node_id,
                 user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
+                graph_runtime_state=graph_runtime_state,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -146,16 +153,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             )
 
         # init graph
-        graph = self._init_graph(graph_config=self._workflow.graph_dict)
+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
+        graph = self._init_graph(
+            graph_config=self._workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
+            workflow_id=self._workflow.id,
+            tenant_id=self._workflow.tenant_id,
+            user_id=self.application_generate_entity.user_id,
+        )
 
         db.session.close()
 
         # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
         workflow_entry = WorkflowEntry(
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
             workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
@@ -167,11 +185,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
             variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
+            command_channel=command_channel,
         )
 
-        generator = workflow_entry.run(
-            callbacks=workflow_callbacks,
-        )
+        generator = workflow_entry.run()
 
         for event in generator:
             self._handle_event(workflow_entry, event)

@@ -31,14 +31,9 @@ from core.app.entities.queue_entities import (
     QueueMessageReplaceEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,
@@ -65,8 +60,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphRuntimeState
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
@@ -387,9 +382,7 @@ class AdvancedChatAppGenerateTaskPipeline:
 
     def _handle_node_failed_events(
         self,
-        event: Union[
-            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
-        ],
+        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
        **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
@@ -434,32 +427,6 @@ class AdvancedChatAppGenerateTaskPipeline:
                 answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
             )
 
-    def _handle_parallel_branch_started_event(
-        self, event: QueueParallelBranchRunStartedEvent, **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch started events."""
-        self._ensure_workflow_initialized()
-
-        parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_start_resp
-
-    def _handle_parallel_branch_finished_events(
-        self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch finished events."""
-        self._ensure_workflow_initialized()
-
-        parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_finish_resp
-
     def _handle_iteration_start_event(
         self, event: QueueIterationStartEvent, **kwargs
     ) -> Generator[StreamResponse, None, None]:
@@ -751,8 +718,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             QueueNodeRetryEvent: self._handle_node_retry_event,
             QueueNodeStartedEvent: self._handle_node_started_event,
             QueueNodeSucceededEvent: self._handle_node_succeeded_event,
-            # Parallel branch events
-            QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
             # Iteration events
             QueueIterationStartEvent: self._handle_iteration_start_event,
             QueueIterationNextEvent: self._handle_iteration_next_event,
@@ -800,8 +765,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             event,
             (
                 QueueNodeFailedEvent,
-                QueueNodeInIterationFailedEvent,
-                QueueNodeInLoopFailedEvent,
                 QueueNodeExceptionEvent,
             ),
         ):
@@ -814,17 +777,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             )
             return
 
-        # Handle parallel branch finished events with isinstance check
-        if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
-            yield from self._handle_parallel_branch_finished_events(
-                event,
-                graph_runtime_state=graph_runtime_state,
-                tts_publisher=tts_publisher,
-                trace_manager=trace_manager,
-                queue_message=queue_message,
-            )
-            return
-
         # For unhandled events, we continue (original behavior)
         return
 
@@ -848,11 +800,6 @@ class AdvancedChatAppGenerateTaskPipeline:
                     graph_runtime_state = event.graph_runtime_state
                     yield from self._handle_workflow_started_event(event)
 
-                case QueueTextChunkEvent():
-                    yield from self._handle_text_chunk_event(
-                        event, tts_publisher=tts_publisher, queue_message=queue_message
-                    )
-
                 case QueueErrorEvent():
                     yield from self._handle_error_event(event)
                     break

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
 from core.app.app_config.entities import VariableEntityType
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file import File, FileUploadConfig
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType
 from core.workflow.repositories.draft_variable_repository import (
     DraftVariableSaver,
     DraftVariableSaverFactory,
@@ -14,6 +14,7 @@ from core.workflow.repositories.draft_variable_repository import (
 )
 from factories import file_factory
 from libs.orjson import orjson_dumps
+from models import Account, EndUser
 from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
 
 if TYPE_CHECKING:
@@ -44,9 +45,9 @@ class BaseAppGenerator:
                 mapping=v,
                 tenant_id=tenant_id,
                 config=FileUploadConfig(
-                    allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                 ),
                 strict_type_validation=strict_type_validation,
             )
@@ -59,9 +60,9 @@ class BaseAppGenerator:
                 mappings=v,
                 tenant_id=tenant_id,
                 config=FileUploadConfig(
-                    allowed_file_types=entity_dictionary[k].allowed_file_types,
-                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                    allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                 ),
             )
             for k, v in user_inputs.items()
@@ -182,8 +183,9 @@ class BaseAppGenerator:
 
     @final
     @staticmethod
-    def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
+    def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
         if invoke_from == InvokeFrom.DEBUGGER:
+            assert isinstance(account, Account)
 
             def draft_var_saver_factory(
                 session: Session,
@@ -200,6 +202,7 @@ class BaseAppGenerator:
                     node_type=node_type,
                     node_execution_id=node_execution_id,
                     enclosing_node_id=enclosing_node_id,
+                    user=account,
                 )
         else:
 

@@ -127,6 +127,21 @@ class AppQueueManager:
         stopped_cache_key = cls._generate_stopped_cache_key(task_id)
         redis_client.setex(stopped_cache_key, 600, 1)
 
+    @classmethod
+    def set_stop_flag_no_user_check(cls, task_id: str) -> None:
+        """
+        Set task stop flag without user permission check.
+        This method allows stopping workflows without user context.
+
+        :param task_id: The task ID to stop
+        :return:
+        """
+        if not task_id:
+            return
+
+        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
+        redis_client.setex(stopped_cache_key, 600, 1)
+
     def _is_stopped(self) -> bool:
         """
         Check if task is stopped

@@ -164,7 +164,9 @@ class ChatAppRunner(AppRunner):
                 config=app_config.dataset,
                 query=query,
                 invoke_from=application_generate_entity.invoke_from,
-                show_retrieve_source=app_config.additional_features.show_retrieve_source,
+                show_retrieve_source=(
+                    app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+                ),
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,

@@ -1,7 +1,7 @@
 import time
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Union, cast
+from typing import Any, Union
 
 from sqlalchemy.orm import Session
 
@@ -16,14 +16,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
     AgentLogStreamResponse,
@@ -36,24 +31,23 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeRetryStreamResponse,
     NodeStartStreamResponse,
-    ParallelBranchFinishedStreamResponse,
-    ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
 )
 from core.file import FILE_MODEL_IDENTITY, File
+from core.plugin.impl.datasource import PluginDatasourceManager
+from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
-from core.workflow.entities.workflow_execution import WorkflowExecution
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
-from core.workflow.nodes import NodeType
-from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
+from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import (
     Account,
     EndUser,
 )
+from services.variable_truncator import VariableTruncator
 
 
 class WorkflowResponseConverter:
@@ -65,6 +59,7 @@ class WorkflowResponseConverter:
     ):
         self._application_generate_entity = application_generate_entity
         self._user = user
+        self._truncator = VariableTruncator.default()
 
     def workflow_start_to_stream_response(
         self,
@@ -156,7 +151,8 @@ class WorkflowResponseConverter:
                 title=workflow_node_execution.title,
                 index=workflow_node_execution.index,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
                 created_at=int(workflow_node_execution.created_at.timestamp()),
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
@@ -171,11 +167,19 @@ class WorkflowResponseConverter:
 
         # extras logic
         if event.node_type == NodeType.TOOL:
-            node_data = cast(ToolNodeData, event.node_data)
             response.data.extras["icon"] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
-                provider_type=node_data.provider_type,
-                provider_id=node_data.provider_id,
+                provider_type=ToolProviderType(event.provider_type),
+                provider_id=event.provider_id,
+            )
+        elif event.node_type == NodeType.DATASOURCE:
+            manager = PluginDatasourceManager()
+            provider_entity = manager.fetch_datasource_provider(
+                self._application_generate_entity.app_config.tenant_id,
+                event.provider_id,
+            )
+            response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
+                self._application_generate_entity.app_config.tenant_id
             )
 
         return response
@@ -183,11 +187,7 @@ class WorkflowResponseConverter:
     def workflow_node_finish_to_stream_response(
         self,
         *,
-        event: QueueNodeSucceededEvent
-        | QueueNodeFailedEvent
-        | QueueNodeInIterationFailedEvent
-        | QueueNodeInLoopFailedEvent
-        | QueueNodeExceptionEvent,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> NodeFinishStreamResponse | None:
@@ -210,9 +210,12 @@ class WorkflowResponseConverter:
                 index=workflow_node_execution.index,
                 title=workflow_node_execution.title,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
-                process_data=workflow_node_execution.process_data,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
+                process_data=workflow_node_execution.get_response_process_data(),
+                process_data_truncated=workflow_node_execution.process_data_truncated,
+                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
+                outputs_truncated=workflow_node_execution.outputs_truncated,
                 status=workflow_node_execution.status,
                 error=workflow_node_execution.error,
                 elapsed_time=workflow_node_execution.elapsed_time,
@@ -221,9 +224,6 @@ class WorkflowResponseConverter:
                 finished_at=int(workflow_node_execution.finished_at.timestamp()),
                 files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
                 parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
             ),
@@ -255,9 +255,12 @@ class WorkflowResponseConverter:
                 index=workflow_node_execution.index,
                 title=workflow_node_execution.title,
                 predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.inputs,
-                process_data=workflow_node_execution.process_data,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
+                inputs=workflow_node_execution.get_response_inputs(),
+                inputs_truncated=workflow_node_execution.inputs_truncated,
+                process_data=workflow_node_execution.get_response_process_data(),
+                process_data_truncated=workflow_node_execution.process_data_truncated,
+                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
+                outputs_truncated=workflow_node_execution.outputs_truncated,
                 status=workflow_node_execution.status,
                 error=workflow_node_execution.error,
                 elapsed_time=workflow_node_execution.elapsed_time,
@@ -275,50 +278,6 @@ class WorkflowResponseConverter:
             ),
         )
 
-    def workflow_parallel_branch_start_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunStartedEvent,
-    ) -> ParallelBranchStartStreamResponse:
-        return ParallelBranchStartStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchStartStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                created_at=int(time.time()),
-            ),
-        )
-
-    def workflow_parallel_branch_finished_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
-    ) -> ParallelBranchFinishedStreamResponse:
-        return ParallelBranchFinishedStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchFinishedStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
-                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
-                created_at=int(time.time()),
-            ),
-        )
-
     def workflow_iteration_start_to_stream_response(
         self,
         *,
@@ -326,6 +285,7 @@ class WorkflowResponseConverter:
         workflow_execution_id: str,
         event: QueueIterationStartEvent,
     ) -> IterationNodeStartStreamResponse:
+        new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
         return IterationNodeStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -333,13 +293,12 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=truncated,
                 metadata=event.metadata or {},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )
 
@@ -357,15 +316,10 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 index=event.index,
-                pre_iteration_output=event.output,
                 created_at=int(time.time()),
                 extras={},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parallel_mode_run_id=event.parallel_mode_run_id,
-                duration=event.duration,
             ),
         )
 
@@ -377,6 +331,11 @@ class WorkflowResponseConverter:
         event: QueueIterationCompletedEvent,
     ) -> IterationNodeCompletedStreamResponse:
         json_converter = WorkflowRuntimeTypeConverter()
+
+        new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
+        new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
+            json_converter.to_json_encodable(event.outputs) or {}
+        )
         return IterationNodeCompletedStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -384,28 +343,29 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
-                outputs=json_converter.to_json_encodable(event.outputs),
+                title=event.node_title,
+                outputs=new_outputs,
+                outputs_truncated=outputs_truncated,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=inputs_truncated,
                 status=WorkflowNodeExecutionStatus.SUCCEEDED
                 if event.error is None
                 else WorkflowNodeExecutionStatus.FAILED,
                 error=None,
                 elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
-                total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
+                total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )
 
     def workflow_loop_start_to_stream_response(
         self, *, task_id: str, workflow_execution_id: str, event: QueueLoopStartEvent
     ) -> LoopNodeStartStreamResponse:
+        new_inputs, truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
         return LoopNodeStartStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -413,10 +373,11 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=truncated,
                 metadata=event.metadata or {},
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
@@ -437,15 +398,16 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 index=event.index,
-                pre_loop_output=event.output,
+                # The `pre_loop_output` field is not utilized by the frontend.
+                # Previously, it was assigned the value of `event.output`.
+                pre_loop_output={},
                 created_at=int(time.time()),
                 extras={},
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
                 parallel_mode_run_id=event.parallel_mode_run_id,
-                duration=event.duration,
             ),
         )
 
@@ -456,6 +418,11 @@ class WorkflowResponseConverter:
         workflow_execution_id: str,
         event: QueueLoopCompletedEvent,
     ) -> LoopNodeCompletedStreamResponse:
+        json_converter = WorkflowRuntimeTypeConverter()
+        new_inputs, inputs_truncated = self._truncator.truncate_variable_mapping(event.inputs or {})
+        new_outputs, outputs_truncated = self._truncator.truncate_variable_mapping(
+            json_converter.to_json_encodable(event.outputs) or {}
+        )
         return LoopNodeCompletedStreamResponse(
             task_id=task_id,
             workflow_run_id=workflow_execution_id,
@@ -463,17 +430,19 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
-                outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
+                title=event.node_title,
+                outputs=new_outputs,
+                outputs_truncated=outputs_truncated,
                 created_at=int(time.time()),
                 extras={},
-                inputs=event.inputs or {},
+                inputs=new_inputs,
+                inputs_truncated=inputs_truncated,
                 status=WorkflowNodeExecutionStatus.SUCCEEDED
                 if event.error is None
                 else WorkflowNodeExecutionStatus.FAILED,
                 error=None,
                 elapsed_time=(naive_utc_now() - event.start_at).total_seconds(),
-                total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
+                total_tokens=(lambda x: x if isinstance(x, int) else 0)(event.metadata.get("total_tokens", 0)),
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,

@@ -124,7 +124,9 @@ class CompletionAppRunner(AppRunner):
                 config=dataset_config,
                 query=query or "",
                 invoke_from=application_generate_entity.invoke_from,
-                show_retrieve_source=app_config.additional_features.show_retrieve_source,
+                show_retrieve_source=app_config.additional_features.show_retrieve_source
+                if app_config.additional_features
+                else False,
                 hit_callback=hit_callback,
                 message_id=message.id,
                 inputs=inputs,

@@ -0,0 +1,95 @@
+from collections.abc import Generator
+from typing import cast
+
+from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
+from core.app.entities.task_entities import (
+    AppStreamResponse,
+    ErrorStreamResponse,
+    NodeFinishStreamResponse,
+    NodeStartStreamResponse,
+    PingStreamResponse,
+    WorkflowAppBlockingResponse,
+    WorkflowAppStreamResponse,
+)
+
+
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
+    _blocking_response_type = WorkflowAppBlockingResponse
+
+    @classmethod
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
+        """
+        Convert blocking full response.
+        :param blocking_response: blocking response
+        :return:
+        """
+        return dict(blocking_response.model_dump())
+
+    @classmethod
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
+        """
+        Convert blocking simple response.
+        :param blocking_response: blocking response
+        :return:
+        """
+        return cls.convert_blocking_full_response(blocking_response)
+
+    @classmethod
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
+        """
+        Convert stream full response.
+        :param stream_response: stream response
+        :return:
+        """
+        for chunk in stream_response:
+            chunk = cast(WorkflowAppStreamResponse, chunk)
+            sub_stream_response = chunk.stream_response
+
+            if isinstance(sub_stream_response, PingStreamResponse):
+                yield "ping"
+                continue
+
+            response_chunk = {
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
+            }
+
+            if isinstance(sub_stream_response, ErrorStreamResponse):
+                data = cls._error_to_stream_response(sub_stream_response.err)
+                response_chunk.update(cast(dict, data))
+            else:
+                response_chunk.update(sub_stream_response.model_dump())
+            yield response_chunk
+
+    @classmethod
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
+        """
+        Convert stream simple response.
+        :param stream_response: stream response
+        :return:
+        """
+        for chunk in stream_response:
+            chunk = cast(WorkflowAppStreamResponse, chunk)
+            sub_stream_response = chunk.stream_response
+
+            if isinstance(sub_stream_response, PingStreamResponse):
+                yield "ping"
+                continue
+
+            response_chunk = {
+                "event": sub_stream_response.event.value,
+                "workflow_run_id": chunk.workflow_run_id,
+            }
+
+            if isinstance(sub_stream_response, ErrorStreamResponse):
+                data = cls._error_to_stream_response(sub_stream_response.err)
+                response_chunk.update(cast(dict, data))
+            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
+                response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
+            else:
+                response_chunk.update(sub_stream_response.model_dump())
+            yield response_chunk

@@ -0,0 +1,66 @@
+from core.app.app_config.base_app_config_manager import BaseAppConfigManager
+from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
+from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
+from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
+from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
+from models.dataset import Pipeline
+from models.model import AppMode
+from models.workflow import Workflow
+
+
+class PipelineConfig(WorkflowUIBasedAppConfig):
+    """
+    Pipeline Config Entity.
+    """
+
+    rag_pipeline_variables: list[RagPipelineVariableEntity] = []
+    pass
+
+
+class PipelineConfigManager(BaseAppConfigManager):
+    @classmethod
+    def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
+        pipeline_config = PipelineConfig(
+            tenant_id=pipeline.tenant_id,
+            app_id=pipeline.id,
+            app_mode=AppMode.RAG_PIPELINE,
+            workflow_id=workflow.id,
+            rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(
+                workflow=workflow, start_node_id=start_node_id
+            ),
+        )
+
+        return pipeline_config
+
+    @classmethod
+    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
+        """
+        Validate for pipeline config
+
+        :param tenant_id: tenant id
+        :param config: app model config args
+        :param only_structure_validate: only validate the structure of the config
+        """
+        related_config_keys = []
+
+        # file upload validation
+        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # text_to_speech
+        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
+        related_config_keys.extend(current_related_config_keys)
+
+        # moderation validation
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
+        )
+        related_config_keys.extend(current_related_config_keys)
+
+        related_config_keys = list(set(related_config_keys))
+
+        # Filter out extra parameters
+        filtered_config = {key: config.get(key) for key in related_config_keys}
+
+        return filtered_config

@ -0,0 +1,851 @@
|
||||||
|
import contextvars
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any, Literal, Union, cast, overload
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
|
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||||
|
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||||
|
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||||
|
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||||
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||||
|
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||||
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceProviderType,
|
||||||
|
OnlineDriveBrowseFilesRequest,
|
||||||
|
)
|
||||||
|
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||||
|
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
|
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||||
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
from models.model import AppMode
|
||||||
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
from services.file_service import FileService
|
||||||
|
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||||
|
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
|
||||||
|
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineGenerator(BaseAppGenerator):
    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int,
        workflow_thread_pool_id: str | None,
        is_retry: bool = False,
    ) -> Generator[Mapping | str, None, None]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int,
        workflow_thread_pool_id: str | None,
        is_retry: bool = False,
    ) -> Mapping[str, Any]: ...

    @overload
    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool,
        call_depth: int,
        workflow_thread_pool_id: str | None,
        is_retry: bool = False,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
        self,
        *,
        pipeline: Pipeline,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: Mapping[str, Any],
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
        workflow_thread_pool_id: str | None = None,
        is_retry: bool = False,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
        # The pipeline must be bound to a dataset
        with Session(db.engine, expire_on_commit=False) as session:
            dataset = pipeline.retrieve_dataset(session)
            if not dataset:
                raise ValueError("Pipeline dataset is required")

        inputs: Mapping[str, Any] = args["inputs"]
        start_node_id: str = args["start_node_id"]
        datasource_type: str = args["datasource_type"]
        datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
            datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
        )
        batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline, workflow=workflow, start_node_id=start_node_id
        )
        documents: list[Document] = []
        if invoke_from == InvokeFrom.PUBLISHED and not is_retry and not args.get("original_document_id"):
            from services.dataset_service import DocumentService

            for datasource_info in datasource_info_list:
                position = DocumentService.get_documents_position(dataset.id)
                document = self._build_document(
                    tenant_id=pipeline.tenant_id,
                    dataset_id=dataset.id,
                    built_in_field_enabled=dataset.built_in_field_enabled,
                    datasource_type=datasource_type,
                    datasource_info=datasource_info,
                    created_from="rag-pipeline",
                    position=position,
                    account=user,
                    batch=batch,
                    document_form=dataset.chunk_structure,
                )
                db.session.add(document)
                documents.append(document)
            db.session.commit()

        # run in child thread
        rag_pipeline_invoke_entities = []
        for i, datasource_info in enumerate(datasource_info_list):
            workflow_run_id = str(uuid.uuid4())
            document_id = args.get("original_document_id") or None
            if invoke_from == InvokeFrom.PUBLISHED and not is_retry:
                document_id = document_id or documents[i].id
                document_pipeline_execution_log = DocumentPipelineExecutionLog(
                    document_id=document_id,
                    datasource_type=datasource_type,
                    datasource_info=json.dumps(datasource_info),
                    datasource_node_id=start_node_id,
                    input_data=inputs,
                    pipeline_id=pipeline.id,
                    created_by=user.id,
                )
                db.session.add(document_pipeline_execution_log)
                db.session.commit()
            application_generate_entity = RagPipelineGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=pipeline_config,
                pipeline_config=pipeline_config,
                datasource_type=datasource_type,
                datasource_info=datasource_info,
                dataset_id=dataset.id,
                original_document_id=args.get("original_document_id"),
                start_node_id=start_node_id,
                batch=batch,
                document_id=document_id,
                inputs=self._prepare_user_inputs(
                    user_inputs=inputs,
                    variables=pipeline_config.rag_pipeline_variables,
                    tenant_id=pipeline.tenant_id,
                    strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
                ),
                files=[],
                user_id=user.id,
                stream=streaming,
                invoke_from=invoke_from,
                call_depth=call_depth,
                workflow_execution_id=workflow_run_id,
            )

            contexts.plugin_tool_providers.set({})
            contexts.plugin_tool_providers_lock.set(threading.Lock())
            if invoke_from == InvokeFrom.DEBUGGER:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
            else:
                workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
            # Create workflow node execution repository
            session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
            workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=workflow_triggered_from,
            )

            workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                session_factory=session_factory,
                user=user,
                app_id=application_generate_entity.app_config.app_id,
                triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
            )
            if invoke_from == InvokeFrom.DEBUGGER or is_retry:
                return self._generate(
                    flask_app=current_app._get_current_object(),  # type: ignore
                    context=contextvars.copy_context(),
                    pipeline=pipeline,
                    workflow_id=workflow.id,
                    user=user,
                    application_generate_entity=application_generate_entity,
                    invoke_from=invoke_from,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
            else:
                rag_pipeline_invoke_entities.append(
                    RagPipelineInvokeEntity(
                        pipeline_id=pipeline.id,
                        user_id=user.id,
                        tenant_id=pipeline.tenant_id,
                        workflow_id=workflow.id,
                        streaming=streaming,
                        workflow_execution_id=workflow_run_id,
                        workflow_thread_pool_id=workflow_thread_pool_id,
                        application_generate_entity=application_generate_entity.model_dump(),
                    )
                )

        if rag_pipeline_invoke_entities:
            # store the rag_pipeline_invoke_entities to object storage
            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
            name = "rag_pipeline_invoke_entities.json"
            # Convert list to proper JSON string
            json_text = json.dumps(text)
            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
            features = FeatureService.get_features(dataset.tenant_id)
            if features.billing.subscription.plan == "sandbox":
                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"

                if redis_client.get(tenant_pipeline_task_key):
                    # Add to waiting queue using List operations (lpush)
                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
                else:
                    # Set flag and execute task
                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
                    rag_pipeline_run_task.delay(  # type: ignore
                        rag_pipeline_invoke_entities_file_id=upload_file.id,
                        tenant_id=dataset.tenant_id,
                    )

            else:
                priority_rag_pipeline_run_task.delay(  # type: ignore
                    rag_pipeline_invoke_entities_file_id=upload_file.id,
                    tenant_id=dataset.tenant_id,
                )

        # return batch, dataset, documents
        return {
            "batch": batch,
            "dataset": PipelineDataset(
                id=dataset.id,
                name=dataset.name,
                description=dataset.description,
                chunk_structure=dataset.chunk_structure,
            ).model_dump(),
            "documents": [
                PipelineDocument(
                    id=document.id,
                    position=document.position,
                    data_source_type=document.data_source_type,
                    data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
                    name=document.name,
                    indexing_status=document.indexing_status,
                    error=document.error,
                    enabled=document.enabled,
                ).model_dump()
                for document in documents
            ],
        }

    def _generate(
        self,
        *,
        flask_app: Flask,
        context: contextvars.Context,
        pipeline: Pipeline,
        workflow_id: str,
        user: Union[Account, EndUser],
        application_generate_entity: RagPipelineGenerateEntity,
        invoke_from: InvokeFrom,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        streaming: bool = True,
        variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
        workflow_thread_pool_id: str | None = None,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
        Generate App response.

        :param pipeline: Pipeline
        :param workflow_id: workflow id
        :param user: account or end user
        :param application_generate_entity: application generate entity
        :param invoke_from: invoke from source
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
        :param workflow_thread_pool_id: workflow thread pool id
        """
        with preserve_flask_contexts(flask_app, context_vars=context):
            # init queue manager
            workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
            if not workflow:
                raise ValueError(f"Workflow not found: {workflow_id}")
            queue_manager = PipelineQueueManager(
                task_id=application_generate_entity.task_id,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                app_mode=AppMode.RAG_PIPELINE,
            )
            context = contextvars.copy_context()

            # new thread
            worker_thread = threading.Thread(
                target=self._generate_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "context": context,
                    "queue_manager": queue_manager,
                    "application_generate_entity": application_generate_entity,
                    "workflow_thread_pool_id": workflow_thread_pool_id,
                    "variable_loader": variable_loader,
                },
            )

            worker_thread.start()

            draft_var_saver_factory = self._get_draft_var_saver_factory(
                invoke_from,
                user,
            )
            # return response or stream generator
            response = self._handle_response(
                application_generate_entity=application_generate_entity,
                workflow=workflow,
                queue_manager=queue_manager,
                user=user,
                workflow_execution_repository=workflow_execution_repository,
                workflow_node_execution_repository=workflow_node_execution_repository,
                stream=streaming,
                draft_var_saver_factory=draft_var_saver_factory,
            )

            return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

    def single_iteration_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param pipeline: Pipeline
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
        )

        with Session(db.engine) as session:
            dataset = pipeline.retrieve_dataset(session)
            if not dataset:
                raise ValueError("Pipeline dataset is required")

        # init application generate entity - use RagPipelineGenerateEntity instead
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            dataset_id=dataset.id,
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            call_depth=0,
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())
        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        draft_var_srv = WorkflowDraftVariableService(db.session())
        draft_var_srv.prefill_conversation_variable_default_values(workflow)
        var_loader = DraftVarLoader(
            engine=db.engine,
            app_id=application_generate_entity.app_config.app_id,
            tenant_id=application_generate_entity.app_config.tenant_id,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            context=contextvars.copy_context(),
            pipeline=pipeline,
            workflow_id=workflow.id,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
            variable_loader=var_loader,
        )

    def single_loop_generate(
        self,
        pipeline: Pipeline,
        workflow: Workflow,
        node_id: str,
        user: Account | EndUser,
        args: Mapping[str, Any],
        streaming: bool = True,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
        """
        Generate App response.

        :param pipeline: Pipeline
        :param workflow: Workflow
        :param node_id: the node id
        :param user: account or end user
        :param args: request args
        :param streaming: is streamed
        """
        if not node_id:
            raise ValueError("node_id is required")

        if args.get("inputs") is None:
            raise ValueError("inputs is required")

        with Session(db.engine) as session:
            dataset = pipeline.retrieve_dataset(session)
            if not dataset:
                raise ValueError("Pipeline dataset is required")

        # convert to app config
        pipeline_config = PipelineConfigManager.get_pipeline_config(
            pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
        )

        # init application generate entity
        application_generate_entity = RagPipelineGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=pipeline_config,
            pipeline_config=pipeline_config,
            datasource_type=args.get("datasource_type", ""),
            datasource_info=args.get("datasource_info", {}),
            batch=args.get("batch", ""),
            document_id=args.get("document_id"),
            dataset_id=dataset.id,
            inputs={},
            files=[],
            user_id=user.id,
            stream=streaming,
            invoke_from=InvokeFrom.DEBUGGER,
            extras={"auto_generate_conversation_name": False},
            single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
            workflow_execution_id=str(uuid.uuid4()),
        )
        contexts.plugin_tool_providers.set({})
        contexts.plugin_tool_providers_lock.set(threading.Lock())

        # Create workflow node execution repository
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
        )

        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
            session_factory=session_factory,
            user=user,
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        draft_var_srv = WorkflowDraftVariableService(db.session())
        draft_var_srv.prefill_conversation_variable_default_values(workflow)
        var_loader = DraftVarLoader(
            engine=db.engine,
            app_id=application_generate_entity.app_config.app_id,
            tenant_id=application_generate_entity.app_config.tenant_id,
        )

        return self._generate(
            flask_app=current_app._get_current_object(),  # type: ignore
            context=contextvars.copy_context(),
            pipeline=pipeline,
            workflow_id=workflow.id,
            user=user,
            invoke_from=InvokeFrom.DEBUGGER,
            application_generate_entity=application_generate_entity,
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
            variable_loader=var_loader,
        )

    def _generate_worker(
        self,
        flask_app: Flask,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        variable_loader: VariableLoader,
        workflow_thread_pool_id: str | None = None,
    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """

        with preserve_flask_contexts(flask_app, context_vars=context):
            try:
                with Session(db.engine, expire_on_commit=False) as session:
                    workflow = session.scalar(
                        select(Workflow).where(
                            Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
                            Workflow.app_id == application_generate_entity.app_config.app_id,
                            Workflow.id == application_generate_entity.app_config.workflow_id,
                        )
                    )
                    if workflow is None:
                        raise ValueError("Workflow not found")

                    # Determine system_user_id based on invocation source
                    is_external_api_call = application_generate_entity.invoke_from in {
                        InvokeFrom.WEB_APP,
                        InvokeFrom.SERVICE_API,
                    }

                    if is_external_api_call:
                        # For external API calls, use end user's session ID
                        end_user = session.scalar(
                            select(EndUser).where(EndUser.id == application_generate_entity.user_id)
                        )
                        system_user_id = end_user.session_id if end_user else ""
                    else:
                        # For internal calls, use the original user ID
                        system_user_id = application_generate_entity.user_id

                # workflow app
                runner = PipelineRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                    variable_loader=variable_loader,
                    workflow=workflow,
                    system_user_id=system_user_id,
                )

                runner.run()
            except GenerateTaskStoppedError:
                pass
            except InvokeAuthorizationError:
                queue_manager.publish_error(
                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                )
            except ValidationError as e:
                logger.exception("Validation Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except ValueError as e:
                if dify_config.DEBUG:
                    logger.exception("Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            except Exception as e:
                logger.exception("Unknown Error when generating")
                queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
            finally:
                db.session.close()

    def _handle_response(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        draft_var_saver_factory: DraftVariableSaverFactory,
        stream: bool = False,
    ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
        Handle response.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param user: account or end user
        :param stream: is stream
        :param workflow_node_execution_repository: repository for workflow node execution
        :return:
        """
        # init generate task pipeline
        generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            workflow=workflow,
            queue_manager=queue_manager,
            user=user,
            stream=stream,
            workflow_node_execution_repository=workflow_node_execution_repository,
            workflow_execution_repository=workflow_execution_repository,
            draft_var_saver_factory=draft_var_saver_factory,
        )

        try:
            return generate_task_pipeline.process()
        except ValueError as e:
            if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedError()
            else:
                logger.exception(
                    "Fails to process generate task pipeline, task_id: %r",
                    application_generate_entity.task_id,
                )
                raise e

    def _build_document(
        self,
        tenant_id: str,
        dataset_id: str,
        built_in_field_enabled: bool,
        datasource_type: str,
        datasource_info: Mapping[str, Any],
        created_from: str,
        position: int,
        account: Union[Account, EndUser],
        batch: str,
        document_form: str,
    ):
        if datasource_type == "local_file":
            name = datasource_info.get("name", "untitled")
        elif datasource_type == "online_document":
            name = datasource_info.get("page", {}).get("page_name", "untitled")
        elif datasource_type == "website_crawl":
            name = datasource_info.get("title", "untitled")
        elif datasource_type == "online_drive":
            name = datasource_info.get("name", "untitled")
        else:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")

        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=position,
            data_source_type=datasource_type,
            data_source_info=json.dumps(datasource_info),
            batch=batch,
            name=name,
            created_from=created_from,
            created_by=account.id,
            doc_form=document_form,
        )
        doc_metadata = {}
        if built_in_field_enabled:
            doc_metadata = {
                BuiltInField.document_name: name,
                BuiltInField.uploader: account.name,
                BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
                BuiltInField.source: datasource_type,
            }
        if doc_metadata:
            document.doc_metadata = doc_metadata
        return document

    def _format_datasource_info_list(
        self,
        datasource_type: str,
        datasource_info_list: list[Mapping[str, Any]],
        pipeline: Pipeline,
        workflow: Workflow,
        start_node_id: str,
        user: Union[Account, EndUser],
    ) -> list[Mapping[str, Any]]:
        """
        Format datasource info list.
        """
        if datasource_type == "online_drive":
            all_files: list[Mapping[str, Any]] = []
            datasource_node_data = None
            datasource_nodes = workflow.graph_dict.get("nodes", [])
            for datasource_node in datasource_nodes:
                if datasource_node.get("id") == start_node_id:
                    datasource_node_data = datasource_node.get("data", {})
                    break
            if not datasource_node_data:
                raise ValueError("Datasource node data not found")

            from core.datasource.datasource_manager import DatasourceManager

            datasource_runtime = DatasourceManager.get_datasource_runtime(
                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
                datasource_name=datasource_node_data.get("datasource_name"),
                tenant_id=pipeline.tenant_id,
                datasource_type=DatasourceProviderType(datasource_type),
            )
            datasource_provider_service = DatasourceProviderService()
            credentials = datasource_provider_service.get_datasource_credentials(
                tenant_id=pipeline.tenant_id,
                provider=datasource_node_data.get("provider_name"),
                plugin_id=datasource_node_data.get("plugin_id"),
                credential_id=datasource_node_data.get("credential_id"),
            )
            if credentials:
                datasource_runtime.runtime.credentials = credentials
            datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)

            for datasource_info in datasource_info_list:
                if datasource_info.get("id") and datasource_info.get("type") == "folder":
                    # get all files in the folder
                    self._get_files_in_folder(
                        datasource_runtime,
                        datasource_info.get("id", ""),
                        datasource_info.get("bucket", None),
                        user.id,
                        all_files,
                        datasource_info,
                        None,
                    )
                else:
                    all_files.append(
                        {
                            "id": datasource_info.get("id", ""),
                            "name": datasource_info.get("name", "untitled"),
                            "bucket": datasource_info.get("bucket", None),
                        }
                    )
            return all_files
        else:
            return datasource_info_list

    def _get_files_in_folder(
        self,
        datasource_runtime: OnlineDriveDatasourcePlugin,
        prefix: str,
        bucket: str | None,
        user_id: str,
        all_files: list,
        datasource_info: Mapping[str, Any],
        next_page_parameters: dict | None = None,
    ):
        """
        Get files in a folder.
        """
        result_generator = datasource_runtime.online_drive_browse_files(
            user_id=user_id,
            request=OnlineDriveBrowseFilesRequest(
                bucket=bucket,
                prefix=prefix,
                max_keys=20,
                next_page_parameters=next_page_parameters,
            ),
            provider_type=datasource_runtime.datasource_provider_type(),
        )
        is_truncated = False
        for result in result_generator:
            for files in result.result:
                for file in files.files:
                    if file.type == "folder":
                        self._get_files_in_folder(
                            datasource_runtime,
                            file.id,
                            bucket,
                            user_id,
                            all_files,
                            datasource_info,
                            None,
                        )
                    else:
                        all_files.append(
                            {
                                "id": file.id,
                                "name": file.name,
                                "bucket": bucket,
                            }
                        )
                is_truncated = files.is_truncated
                next_page_parameters = files.next_page_parameters

        if is_truncated:
            self._get_files_in_folder(
                datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
            )

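The sandbox branch of PipelineGenerator.generate() above gates pipeline runs per tenant: a Redis flag marks a run in flight, and further payloads are parked on a per-tenant waiting list. A minimal, self-contained sketch of that gating pattern, assuming a plain redis-py connection; the `dispatch` callable stands in for the Celery task's `.delay()` and is not part of this commit:

import redis


def enqueue_or_dispatch(r: redis.Redis, tenant_id: str, file_id: str, dispatch) -> None:
    task_key = f"tenant_pipeline_task:{tenant_id}"  # flag: a run is already in flight
    queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"  # waiting list for later payloads

    if r.get(task_key):
        # A run is in flight for this tenant: park the payload for later.
        r.lpush(queue_key, file_id)
    else:
        # No run in flight: claim the slot for up to an hour, then hand off to the task.
        r.set(task_key, 1, ex=60 * 60)
        dispatch(rag_pipeline_invoke_entities_file_id=file_id, tenant_id=tenant_id)

As in the code above, the flag is read and written in two separate Redis calls, so two concurrent callers could both dispatch; an atomic SET with NX would close that window if it ever matters.
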
@@ -0,0 +1,45 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueStopEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowPartialSuccessEvent,
    QueueWorkflowSucceededEvent,
    WorkflowQueueMessage,
)


class PipelineQueueManager(AppQueueManager):
    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
        super().__init__(task_id, user_id, invoke_from)

        self._app_mode = app_mode

    def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
        """
        Publish event to queue
        :param event:
        :param pub_from:
        :return:
        """
        message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)

        self._q.put(message)

        if isinstance(
            event,
            QueueStopEvent
            | QueueErrorEvent
            | QueueMessageEndEvent
            | QueueWorkflowSucceededEvent
            | QueueWorkflowFailedEvent
            | QueueWorkflowPartialSuccessEvent,
        ):
            self.stop_listen()

        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise GenerateTaskStoppedError()

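PipelineQueueManager above carries events from the worker thread back to the caller and stops listening once a terminal event (stop, error, message end, or a workflow completion event) arrives. A self-contained toy of that worker-thread/event-queue shape, assuming nothing beyond the standard library; the event names are placeholders, not the Dify queue entities:

import queue
import threading


def worker(events: queue.Queue) -> None:
    # Stand-in for PipelineRunner.run() publishing events through the queue manager.
    for name in ("node_started", "node_finished", "workflow_succeeded"):
        events.put(name)
    events.put(None)  # terminal marker: the real manager stops listening on terminal events


def stream_events() -> list[str]:
    events: queue.Queue = queue.Queue()
    threading.Thread(target=worker, args=(events,), daemon=True).start()
    seen: list[str] = []
    while (event := events.get()) is not None:
        seen.append(event)  # the real pipeline turns these into blocking or streaming responses
    return seen


if __name__ == "__main__":
    print(stream_events())  # ['node_started', 'node_finished', 'workflow_succeeded']

The real manager additionally raises GenerateTaskStoppedError when the application manager keeps publishing after a stop, which is how the worker thread above winds down early.
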
@@ -0,0 +1,280 @@
import logging
import time
from typing import cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow

logger = logging.getLogger(__name__)


class PipelineRunner(WorkflowBasedAppRunner):
    """
    Pipeline Application Runner
    """

    def __init__(
        self,
        application_generate_entity: RagPipelineGenerateEntity,
        queue_manager: AppQueueManager,
        variable_loader: VariableLoader,
        workflow: Workflow,
        system_user_id: str,
        workflow_thread_pool_id: str | None = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        super().__init__(
            queue_manager=queue_manager,
            variable_loader=variable_loader,
            app_id=application_generate_entity.app_config.app_id,
        )
        self.application_generate_entity = application_generate_entity
        self.workflow_thread_pool_id = workflow_thread_pool_id
        self._workflow = workflow
        self._sys_user_id = system_user_id

    def _get_app_id(self) -> str:
        return self.application_generate_entity.app_config.app_id

    def run(self) -> None:
        """
        Run application
        """
        app_config = self.application_generate_entity.app_config
        app_config = cast(PipelineConfig, app_config)

        user_id = None
        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
            end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = self.application_generate_entity.user_id

        pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
        if not pipeline:
            raise ValueError("Pipeline not found")

        workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        db.session.close()

        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            graph_runtime_state = GraphRuntimeState(
                variable_pool=VariablePool.empty(),
                start_at=time.time(),
            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
                graph_runtime_state=graph_runtime_state,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
            graph_runtime_state = GraphRuntimeState(
                variable_pool=VariablePool.empty(),
                start_at=time.time(),
            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=self.application_generate_entity.single_loop_run.inputs,
                graph_runtime_state=graph_runtime_state,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = SystemVariable(
                files=files,
                user_id=user_id,
                app_id=app_config.app_id,
                workflow_id=app_config.workflow_id,
                workflow_execution_id=self.application_generate_entity.workflow_execution_id,
                document_id=self.application_generate_entity.document_id,
                original_document_id=self.application_generate_entity.original_document_id,
                batch=self.application_generate_entity.batch,
                dataset_id=self.application_generate_entity.dataset_id,
                datasource_type=self.application_generate_entity.datasource_type,
                datasource_info=self.application_generate_entity.datasource_info,
                invoke_from=self.application_generate_entity.invoke_from.value,
            )

            rag_pipeline_variables = []
            if workflow.rag_pipeline_variables:
                for v in workflow.rag_pipeline_variables:
                    rag_pipeline_variable = RAGPipelineVariable(**v)
                    if (
                        rag_pipeline_variable.belong_to_node_id
                        in (self.application_generate_entity.start_node_id, "shared")
                    ) and rag_pipeline_variable.variable in inputs:
                        rag_pipeline_variables.append(
                            RAGPipelineVariableInput(
                                variable=rag_pipeline_variable,
                                value=inputs[rag_pipeline_variable.variable],
                            )
                        )

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
                rag_pipeline_variables=rag_pipeline_variables,
            )
            graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

            # init graph
            graph = self._init_rag_pipeline_graph(
                graph_runtime_state=graph_runtime_state,
                start_node_id=self.application_generate_entity.start_node_id,
                workflow=workflow,
            )

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            graph_runtime_state=graph_runtime_state,
            variable_pool=variable_pool,
        )

        generator = workflow_entry.run()

        for event in generator:
            self._update_document_status(
                event, self.application_generate_entity.document_id, self.application_generate_entity.dataset_id
            )
            self._handle_event(workflow_entry, event)

    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
            .first()
        )

        # return workflow
        return workflow

    def _init_rag_pipeline_graph(
        self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
    ) -> Graph:
        """
        Init pipeline graph
        """
        graph_config = workflow.graph_dict
        if "nodes" not in graph_config or "edges" not in graph_config:
            raise ValueError("nodes or edges not found in workflow graph")

        if not isinstance(graph_config.get("nodes"), list):
            raise ValueError("nodes in workflow graph must be a list")

        if not isinstance(graph_config.get("edges"), list):
            raise ValueError("edges in workflow graph must be a list")
        # nodes = graph_config.get("nodes", [])
        # edges = graph_config.get("edges", [])
        # real_run_nodes = []
        # real_edges = []
        # exclude_node_ids = []
        # for node in nodes:
        #     node_id = node.get("id")
        #     node_type = node.get("data", {}).get("type", "")
        #     if node_type == "datasource":
        #         if start_node_id != node_id:
        #             exclude_node_ids.append(node_id)
        #             continue
        #     real_run_nodes.append(node)

        # for edge in edges:
        #     if edge.get("source") in exclude_node_ids:
        #         continue
        #     real_edges.append(edge)
        # graph_config = dict(graph_config)
        # graph_config["nodes"] = real_run_nodes
        # graph_config["edges"] = real_edges

        # init graph
        # Create required parameters for Graph.init
        graph_init_params = GraphInitParams(
            tenant_id=workflow.tenant_id,
            app_id=self._app_id,
            workflow_id=workflow.id,
            graph_config=graph_config,
            user_id=self.application_generate_entity.user_id,
            user_from=UserFrom.ACCOUNT.value,
            invoke_from=InvokeFrom.SERVICE_API.value,
            call_depth=0,
        )

        node_factory = DifyNodeFactory(
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)

        if not graph:
            raise ValueError("graph not found in workflow")

        return graph

    def _update_document_status(self, event: GraphEngineEvent, document_id: str | None, dataset_id: str | None) -> None:
        """
        Update document status
        """
        if isinstance(event, GraphRunFailedEvent):
            if document_id and dataset_id:
                document = (
                    db.session.query(Document)
                    .where(Document.id == document_id, Document.dataset_id == dataset_id)
                    .first()
                )
                if document:
                    document.indexing_status = "error"
                    document.error = event.error or "Unknown error"
                    db.session.add(document)
                    db.session.commit()

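PipelineRunner.run() above forwards every graph event to the base runner and, through _update_document_status(), marks the target document as errored when the graph run fails. A simplified sketch of that loop, with plain dataclasses standing in for the ORM Document model and the graph events used in this commit:

from dataclasses import dataclass


@dataclass
class FakeDocument:
    id: str
    indexing_status: str = "indexing"
    error: str | None = None


@dataclass
class FakeGraphEvent:
    failed: bool = False
    error: str | None = None


def drive_events(events, document, handle) -> None:
    for event in events:
        if event.failed:
            # mirrors _update_document_status(): record the failure on the document
            document.indexing_status = "error"
            document.error = event.error or "Unknown error"
        handle(event)  # stand-in for self._handle_event(workflow_entry, event)


doc = FakeDocument(id="doc-1")
drive_events([FakeGraphEvent(), FakeGraphEvent(failed=True, error="datasource timeout")], doc, handle=lambda e: None)
assert doc.indexing_status == "error"
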
@@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        invoke_from: InvokeFrom,
        streaming: Literal[True],
        call_depth: int,
-        workflow_thread_pool_id: str | None,
    ) -> Generator[Mapping | str, None, None]: ...

    @overload
@@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        invoke_from: InvokeFrom,
        streaming: Literal[False],
        call_depth: int,
-        workflow_thread_pool_id: str | None,
    ) -> Mapping[str, Any]: ...

    @overload
@@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        invoke_from: InvokeFrom,
        streaming: bool,
        call_depth: int,
-        workflow_thread_pool_id: str | None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

    def generate(
@@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        invoke_from: InvokeFrom,
        streaming: bool = True,
        call_depth: int = 0,
-        workflow_thread_pool_id: str | None = None,
    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
        files: Sequence[Mapping[str, Any]] = args.get("files") or []

@@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
            streaming=streaming,
-            workflow_thread_pool_id=workflow_thread_pool_id,
        )

    def _generate(
@@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        streaming: bool = True,
-        workflow_thread_pool_id: str | None = None,
        variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
    ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
        """
@@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param streaming: is stream
-        :param workflow_thread_pool_id: workflow thread pool id
        """
        # init queue manager
        queue_manager = WorkflowAppQueueManager(
@@ -237,16 +230,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
                    "application_generate_entity": application_generate_entity,
                    "queue_manager": queue_manager,
                    "context": context,
-                    "workflow_thread_pool_id": workflow_thread_pool_id,
                    "variable_loader": variable_loader,
                },
            )

            worker_thread.start()

-            draft_var_saver_factory = self._get_draft_var_saver_factory(
-                invoke_from,
-            )
+            draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user)

            # return response or stream generator
            response = self._handle_response(
@@ -434,8 +424,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        queue_manager: AppQueueManager,
        context: contextvars.Context,
        variable_loader: VariableLoader,
-        workflow_thread_pool_id: str | None = None,
-    ):
+    ) -> None:
        """
        Generate worker in a new thread.
        :param flask_app: Flask app
@@ -444,7 +433,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :param workflow_thread_pool_id: workflow thread pool id
        :return:
        """
-
        with preserve_flask_contexts(flask_app, context_vars=context):
            with Session(db.engine, expire_on_commit=False) as session:
                workflow = session.scalar(
@@ -474,7 +462,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                runner = WorkflowAppRunner(
                    application_generate_entity=application_generate_entity,
                    queue_manager=queue_manager,
-                    workflow_thread_pool_id=workflow_thread_pool_id,
                    variable_loader=variable_loader,
                    workflow=workflow,
                    system_user_id=system_user_id,

@@ -1,7 +1,7 @@
import logging
+import time
from typing import cast

-from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    WorkflowAppGenerateEntity,
)
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_redis import redis_client
from models.enums import UserFrom
-from models.workflow import Workflow, WorkflowType
+from models.workflow import Workflow

logger = logging.getLogger(__name__)

@@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        variable_loader: VariableLoader,
-        workflow_thread_pool_id: str | None = None,
        workflow: Workflow,
        system_user_id: str,
    ):
@@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            app_id=application_generate_entity.app_config.app_id,
        )
        self.application_generate_entity = application_generate_entity
-        self.workflow_thread_pool_id = workflow_thread_pool_id
        self._workflow = workflow
        self._sys_user_id = system_user_id

@@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
        app_config = self.application_generate_entity.app_config
        app_config = cast(WorkflowAppConfig, app_config)

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=self._workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                graph_runtime_state=graph_runtime_state,
            )
        elif self.application_generate_entity.single_loop_run:
            # if only single loop run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                workflow=self._workflow,
                node_id=self.application_generate_entity.single_loop_run.node_id,
                user_inputs=self.application_generate_entity.single_loop_run.inputs,
+                graph_runtime_state=graph_runtime_state,
            )
        else:
            inputs = self.application_generate_entity.inputs
@@ -92,15 +97,27 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
                conversation_variables=[],
            )

+            graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
            # init graph
-            graph = self._init_graph(graph_config=self._workflow.graph_dict)
+            graph = self._init_graph(
+                graph_config=self._workflow.graph_dict,
+                graph_runtime_state=graph_runtime_state,
+                workflow_id=self._workflow.id,
+                tenant_id=self._workflow.tenant_id,
+                user_id=self.application_generate_entity.user_id,
+            )

        # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
        workflow_entry = WorkflowEntry(
            tenant_id=self._workflow.tenant_id,
            app_id=self._workflow.app_id,
            workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
            graph=graph,
            graph_config=self._workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
@@ -112,10 +129,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
-            thread_pool_id=self.workflow_thread_pool_id,
+            graph_runtime_state=graph_runtime_state,
|
||||||
|
command_channel=command_channel,
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
generator = workflow_entry.run()
|
||||||
|
|
||||||
for event in generator:
|
for event in generator:
|
||||||
self._handle_event(workflow_entry, event)
|
self._handle_event(workflow_entry, event)
|
||||||
|
|
|
||||||
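The hunks above change how WorkflowAppRunner drives a run: a single GraphRuntimeState is now built before the graph, a per-task Redis command channel replaces the thread-pool id, and WorkflowEntry.run() no longer takes callbacks. Below is a minimal sketch of the new wiring, using only names that appear in this diff (GraphRuntimeState, VariablePool.empty, RedisChannel, redis_client); the helper function itself is illustrative and not part of the commit.

import time

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from extensions.ext_redis import redis_client


def build_run_context(task_id: str) -> tuple[GraphRuntimeState, RedisChannel]:
    # One runtime state is created up front and shared by the node factory,
    # the graph, and the WorkflowEntry instead of being created inside the engine.
    # (The empty pool stands in for whichever variable pool the runner built.)
    graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool.empty(),
        start_at=time.perf_counter(),
    )
    # Commands for this execution (for example stop requests) travel over a
    # Redis channel keyed by the task id, mirroring the diff above.
    command_channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
    return graph_runtime_state, command_channel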
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Union
|
from typing import Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
|
||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
|
AppQueueEvent,
|
||||||
MessageQueueMessage,
|
MessageQueueMessage,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueErrorEvent,
|
QueueErrorEvent,
|
||||||
|
|
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
|
||||||
QueueLoopStartEvent,
|
QueueLoopStartEvent,
|
||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
|
||||||
QueueNodeInLoopFailedEvent,
|
|
||||||
QueueNodeRetryEvent,
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
|
||||||
QueueParallelBranchRunStartedEvent,
|
|
||||||
QueueParallelBranchRunSucceededEvent,
|
|
||||||
QueuePingEvent,
|
QueuePingEvent,
|
||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
|
|
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
|
|
||||||
def _handle_node_failed_events(
|
def _handle_node_failed_events(
|
||||||
self,
|
self,
|
||||||
event: Union[
|
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
|
||||||
],
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Generator[StreamResponse, None, None]:
|
) -> Generator[StreamResponse, None, None]:
|
||||||
"""Handle various node failure events."""
|
"""Handle various node failure events."""
|
||||||
|
|
@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
if node_failed_response:
|
if node_failed_response:
|
||||||
yield node_failed_response
|
yield node_failed_response
|
||||||
|
|
||||||
def _handle_parallel_branch_started_event(
|
|
||||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle parallel branch started events."""
|
|
||||||
self._ensure_workflow_initialized()
|
|
||||||
|
|
||||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_execution_id=self._workflow_run_id,
|
|
||||||
event=event,
|
|
||||||
)
|
|
||||||
yield parallel_start_resp
|
|
||||||
|
|
||||||
def _handle_parallel_branch_finished_events(
|
|
||||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle parallel branch finished events."""
|
|
||||||
self._ensure_workflow_initialized()
|
|
||||||
|
|
||||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_execution_id=self._workflow_run_id,
|
|
||||||
event=event,
|
|
||||||
)
|
|
||||||
yield parallel_finish_resp
|
|
||||||
|
|
||||||
def _handle_iteration_start_event(
|
def _handle_iteration_start_event(
|
||||||
self, event: QueueIterationStartEvent, **kwargs
|
self, event: QueueIterationStartEvent, **kwargs
|
||||||
) -> Generator[StreamResponse, None, None]:
|
) -> Generator[StreamResponse, None, None]:
|
||||||
|
|
@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||||
# Parallel branch events
|
|
||||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
|
||||||
# Iteration events
|
# Iteration events
|
||||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||||
|
|
@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
|
|
||||||
def _dispatch_event(
|
def _dispatch_event(
|
||||||
self,
|
self,
|
||||||
event: Any,
|
event: AppQueueEvent,
|
||||||
*,
|
*,
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
graph_runtime_state: GraphRuntimeState | None = None,
|
||||||
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
tts_publisher: AppGeneratorTTSPublisher | None = None,
|
||||||
|
|
@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
event,
|
event,
|
||||||
(
|
(
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
|
||||||
QueueNodeInLoopFailedEvent,
|
|
||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|
@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle parallel branch finished events with isinstance check
|
|
||||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
|
||||||
yield from self._handle_parallel_branch_finished_events(
|
|
||||||
event,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
tts_publisher=tts_publisher,
|
|
||||||
trace_manager=trace_manager,
|
|
||||||
queue_message=queue_message,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle workflow failed and stop events with isinstance check
|
# Handle workflow failed and stop events with isinstance check
|
||||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||||
yield from self._handle_workflow_failed_and_stop_events(
|
yield from self._handle_workflow_failed_and_stop_events(
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
|
|
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
||||||
QueueLoopStartEvent,
|
QueueLoopStartEvent,
|
||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
|
||||||
QueueNodeInLoopFailedEvent,
|
|
||||||
QueueNodeRetryEvent,
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
|
||||||
QueueParallelBranchRunStartedEvent,
|
|
||||||
QueueParallelBranchRunSucceededEvent,
|
|
||||||
QueueRetrieverResourcesEvent,
|
QueueRetrieverResourcesEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
|
|
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_events import (
|
||||||
AgentLogEvent,
|
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
IterationRunFailedEvent,
|
NodeRunAgentLogEvent,
|
||||||
IterationRunNextEvent,
|
|
||||||
IterationRunStartedEvent,
|
|
||||||
IterationRunSucceededEvent,
|
|
||||||
LoopRunFailedEvent,
|
|
||||||
LoopRunNextEvent,
|
|
||||||
LoopRunStartedEvent,
|
|
||||||
LoopRunSucceededEvent,
|
|
||||||
NodeInIterationFailedEvent,
|
|
||||||
NodeInLoopFailedEvent,
|
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
|
NodeRunIterationFailedEvent,
|
||||||
|
NodeRunIterationNextEvent,
|
||||||
|
NodeRunIterationStartedEvent,
|
||||||
|
NodeRunIterationSucceededEvent,
|
||||||
|
NodeRunLoopFailedEvent,
|
||||||
|
NodeRunLoopNextEvent,
|
||||||
|
NodeRunLoopStartedEvent,
|
||||||
|
NodeRunLoopSucceededEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunRetryEvent,
|
NodeRunRetryEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
ParallelBranchRunFailedEvent,
|
|
||||||
ParallelBranchRunStartedEvent,
|
|
||||||
ParallelBranchRunSucceededEvent,
|
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from models.enums import UserFrom
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
|
||||||
self._variable_loader = variable_loader
|
self._variable_loader = variable_loader
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
|
||||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
def _init_graph(
|
||||||
|
self,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
|
workflow_id: str = "",
|
||||||
|
tenant_id: str = "",
|
||||||
|
user_id: str = "",
|
||||||
|
) -> Graph:
|
||||||
"""
|
"""
|
||||||
Init graph
|
Init graph
|
||||||
"""
|
"""
|
||||||
|
|
@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
|
||||||
|
|
||||||
if not isinstance(graph_config.get("edges"), list):
|
if not isinstance(graph_config.get("edges"), list):
|
||||||
raise ValueError("edges in workflow graph must be a list")
|
raise ValueError("edges in workflow graph must be a list")
|
||||||
|
|
||||||
|
# Create required parameters for Graph.init
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id=tenant_id or "",
|
||||||
|
app_id=self._app_id,
|
||||||
|
workflow_id=workflow_id,
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id=user_id,
|
||||||
|
user_from=UserFrom.ACCOUNT.value,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the provided graph_runtime_state for consistent state management
|
||||||
|
|
||||||
|
node_factory = DifyNodeFactory(
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError("graph not found in workflow")
|
raise ValueError("graph not found in workflow")
|
||||||
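_init_graph now threads the shared GraphRuntimeState into node construction via GraphInitParams and DifyNodeFactory, and Graph.init receives the factory. A condensed sketch of that initialization chain, assembled only from calls visible in this diff; the standalone helper and its default arguments are illustrative, not part of the commit.

from collections.abc import Mapping
from typing import Any

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from models.enums import UserFrom


def init_graph(
    graph_config: Mapping[str, Any],
    graph_runtime_state: GraphRuntimeState,
    *,
    app_id: str,
    tenant_id: str = "",
    workflow_id: str = "",
    user_id: str = "",
) -> Graph:
    # Static identifiers for the run; the node factory needs them alongside
    # the shared runtime state.
    init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT.value,
        invoke_from=InvokeFrom.SERVICE_API.value,
        call_depth=0,
    )
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    # Every node built by the factory now shares the same runtime state.
    return Graph.init(graph_config=graph_config, node_factory=node_factory)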
|
|
@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
) -> tuple[Graph, VariablePool]:
|
) -> tuple[Graph, VariablePool]:
|
||||||
"""
|
"""
|
||||||
Get variable pool of single iteration
|
Get variable pool of single iteration
|
||||||
|
|
@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
|
||||||
|
|
||||||
graph_config["edges"] = edge_configs
|
graph_config["edges"] = edge_configs
|
||||||
|
|
||||||
|
# Create required parameters for Graph.init
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id=workflow.tenant_id,
|
||||||
|
app_id=self._app_id,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="",
|
||||||
|
user_from=UserFrom.ACCOUNT.value,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_factory = DifyNodeFactory(
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError("graph not found in workflow")
|
raise ValueError("graph not found in workflow")
|
||||||
|
|
@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
) -> tuple[Graph, VariablePool]:
|
) -> tuple[Graph, VariablePool]:
|
||||||
"""
|
"""
|
||||||
Get variable pool of single loop
|
Get variable pool of single loop
|
||||||
|
|
@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
|
||||||
|
|
||||||
graph_config["edges"] = edge_configs
|
graph_config["edges"] = edge_configs
|
||||||
|
|
||||||
|
# Create required parameters for Graph.init
|
||||||
|
graph_init_params = GraphInitParams(
|
||||||
|
tenant_id=workflow.tenant_id,
|
||||||
|
app_id=self._app_id,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="",
|
||||||
|
user_from=UserFrom.ACCOUNT.value,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_factory = DifyNodeFactory(
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError("graph not found in workflow")
|
raise ValueError("graph not found in workflow")
|
||||||
|
|
@ -310,39 +366,32 @@ class WorkflowBasedAppRunner:
|
||||||
)
|
)
|
||||||
elif isinstance(event, GraphRunFailedEvent):
|
elif isinstance(event, GraphRunFailedEvent):
|
||||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||||
|
elif isinstance(event, GraphRunAbortedEvent):
|
||||||
|
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||||
elif isinstance(event, NodeRunRetryEvent):
|
elif isinstance(event, NodeRunRetryEvent):
|
||||||
node_run_result = event.route_node_state.node_run_result
|
node_run_result = event.node_run_result
|
||||||
inputs: Mapping[str, Any] | None = {}
|
inputs = node_run_result.inputs
|
||||||
process_data: Mapping[str, Any] | None = {}
|
process_data = node_run_result.process_data
|
||||||
outputs: Mapping[str, Any] | None = {}
|
outputs = node_run_result.outputs
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
execution_metadata = node_run_result.metadata
|
||||||
if node_run_result:
|
|
||||||
inputs = node_run_result.inputs
|
|
||||||
process_data = node_run_result.process_data
|
|
||||||
outputs = node_run_result.outputs
|
|
||||||
execution_metadata = node_run_result.metadata
|
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueNodeRetryEvent(
|
QueueNodeRetryEvent(
|
||||||
node_execution_id=event.id,
|
node_execution_id=event.id,
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
|
node_title=event.node_title,
|
||||||
node_type=event.node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.node_data,
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.start_at,
|
start_at=event.start_at,
|
||||||
node_run_index=event.route_node_state.index,
|
|
||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
error=event.error,
|
error=event.error,
|
||||||
execution_metadata=execution_metadata,
|
execution_metadata=execution_metadata,
|
||||||
retry_index=event.retry_index,
|
retry_index=event.retry_index,
|
||||||
|
provider_type=event.provider_type,
|
||||||
|
provider_id=event.provider_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunStartedEvent):
|
elif isinstance(event, NodeRunStartedEvent):
|
||||||
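Because the new graph events expose node_run_result directly (instead of going through route_node_state), the retry and failure mappings above collapse to plain attribute reads. A schematic of that pattern, written as a free-standing helper for illustration only:

def queue_fields_from(event) -> dict:
    # The result now lives on the event itself, so the earlier
    # None-guarding of route_node_state.node_run_result is unnecessary.
    result = event.node_run_result
    return {
        "inputs": result.inputs,
        "process_data": result.process_data,
        "outputs": result.outputs,
        "execution_metadata": result.metadata,
        "error": result.error or "Unknown error",
    }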
|
|
@ -350,44 +399,29 @@ class WorkflowBasedAppRunner:
|
||||||
QueueNodeStartedEvent(
|
QueueNodeStartedEvent(
|
||||||
node_execution_id=event.id,
|
node_execution_id=event.id,
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
|
node_title=event.node_title,
|
||||||
node_type=event.node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.node_data,
|
start_at=event.start_at,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.route_node_state.start_at,
|
|
||||||
node_run_index=event.route_node_state.index,
|
|
||||||
predecessor_node_id=event.predecessor_node_id,
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
agent_strategy=event.agent_strategy,
|
agent_strategy=event.agent_strategy,
|
||||||
|
provider_type=event.provider_type,
|
||||||
|
provider_id=event.provider_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, NodeRunSucceededEvent):
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
node_run_result = event.route_node_state.node_run_result
|
node_run_result = event.node_run_result
|
||||||
if node_run_result:
|
inputs = node_run_result.inputs
|
||||||
inputs = node_run_result.inputs
|
process_data = node_run_result.process_data
|
||||||
process_data = node_run_result.process_data
|
outputs = node_run_result.outputs
|
||||||
outputs = node_run_result.outputs
|
execution_metadata = node_run_result.metadata
|
||||||
execution_metadata = node_run_result.metadata
|
|
||||||
else:
|
|
||||||
inputs = {}
|
|
||||||
process_data = {}
|
|
||||||
outputs = {}
|
|
||||||
execution_metadata = {}
|
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueNodeSucceededEvent(
|
QueueNodeSucceededEvent(
|
||||||
node_execution_id=event.id,
|
node_execution_id=event.id,
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.node_data,
|
start_at=event.start_at,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.route_node_state.start_at,
|
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
process_data=process_data,
|
process_data=process_data,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
|
|
@ -396,34 +430,18 @@ class WorkflowBasedAppRunner:
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(event, NodeRunFailedEvent):
|
elif isinstance(event, NodeRunFailedEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueNodeFailedEvent(
|
QueueNodeFailedEvent(
|
||||||
node_execution_id=event.id,
|
node_execution_id=event.id,
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.node_data,
|
start_at=event.start_at,
|
||||||
parallel_id=event.parallel_id,
|
inputs=event.node_run_result.inputs,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
process_data=event.node_run_result.process_data,
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
outputs=event.node_run_result.outputs,
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
error=event.node_run_result.error or "Unknown error",
|
||||||
start_at=event.route_node_state.start_at,
|
execution_metadata=event.node_run_result.metadata,
|
||||||
inputs=event.route_node_state.node_run_result.inputs
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
process_data=event.route_node_state.node_run_result.process_data
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
error=event.route_node_state.node_run_result.error
|
|
||||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
|
||||||
else "Unknown error",
|
|
||||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
)
|
)
|
||||||
|
|
@ -434,93 +452,21 @@ class WorkflowBasedAppRunner:
|
||||||
node_execution_id=event.id,
|
node_execution_id=event.id,
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.node_data,
|
start_at=event.start_at,
|
||||||
parallel_id=event.parallel_id,
|
inputs=event.node_run_result.inputs,
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
process_data=event.node_run_result.process_data,
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
outputs=event.node_run_result.outputs,
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
error=event.node_run_result.error or "Unknown error",
|
||||||
start_at=event.route_node_state.start_at,
|
execution_metadata=event.node_run_result.metadata,
|
||||||
inputs=event.route_node_state.node_run_result.inputs
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
process_data=event.route_node_state.node_run_result.process_data
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
outputs=event.route_node_state.node_run_result.outputs
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
error=event.route_node_state.node_run_result.error
|
|
||||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
|
||||||
else "Unknown error",
|
|
||||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(event, NodeInIterationFailedEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueNodeInIterationFailedEvent(
|
|
||||||
node_execution_id=event.id,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
node_data=event.node_data,
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.route_node_state.start_at,
|
|
||||||
inputs=event.route_node_state.node_run_result.inputs
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
process_data=event.route_node_state.node_run_result.process_data
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
|
||||||
error=event.error,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, NodeInLoopFailedEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueNodeInLoopFailedEvent(
|
|
||||||
node_execution_id=event.id,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
node_data=event.node_data,
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.route_node_state.start_at,
|
|
||||||
inputs=event.route_node_state.node_run_result.inputs
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
process_data=event.route_node_state.node_run_result.process_data
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
|
||||||
if event.route_node_state.node_run_result
|
|
||||||
else {},
|
|
||||||
in_loop_id=event.in_loop_id,
|
|
||||||
error=event.error,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueTextChunkEvent(
|
QueueTextChunkEvent(
|
||||||
text=event.chunk_content,
|
text=event.chunk,
|
||||||
from_variable_selector=event.from_variable_selector,
|
from_variable_selector=list(event.selector),
|
||||||
in_iteration_id=event.in_iteration_id,
|
in_iteration_id=event.in_iteration_id,
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
)
|
)
|
||||||
|
|
@ -533,10 +479,10 @@ class WorkflowBasedAppRunner:
|
||||||
in_loop_id=event.in_loop_id,
|
in_loop_id=event.in_loop_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, AgentLogEvent):
|
elif isinstance(event, NodeRunAgentLogEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueAgentLogEvent(
|
QueueAgentLogEvent(
|
||||||
id=event.id,
|
id=event.message_id,
|
||||||
label=event.label,
|
label=event.label,
|
||||||
node_execution_id=event.node_execution_id,
|
node_execution_id=event.node_execution_id,
|
||||||
parent_id=event.parent_id,
|
parent_id=event.parent_id,
|
||||||
|
|
@ -547,51 +493,13 @@ class WorkflowBasedAppRunner:
|
||||||
node_id=event.node_id,
|
node_id=event.node_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||||
self._publish_event(
|
|
||||||
QueueParallelBranchRunStartedEvent(
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
|
||||||
in_loop_id=event.in_loop_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueParallelBranchRunSucceededEvent(
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
|
||||||
in_loop_id=event.in_loop_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueParallelBranchRunFailedEvent(
|
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
in_iteration_id=event.in_iteration_id,
|
|
||||||
in_loop_id=event.in_loop_id,
|
|
||||||
error=event.error,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, IterationRunStartedEvent):
|
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueIterationStartEvent(
|
QueueIterationStartEvent(
|
||||||
node_execution_id=event.iteration_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.iteration_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.iteration_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.iteration_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.start_at,
|
start_at=event.start_at,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
inputs=event.inputs,
|
inputs=event.inputs,
|
||||||
|
|
@ -599,55 +507,41 @@ class WorkflowBasedAppRunner:
|
||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, IterationRunNextEvent):
|
elif isinstance(event, NodeRunIterationNextEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueIterationNextEvent(
|
QueueIterationNextEvent(
|
||||||
node_execution_id=event.iteration_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.iteration_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.iteration_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.iteration_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
index=event.index,
|
index=event.index,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
output=event.pre_iteration_output,
|
output=event.pre_iteration_output,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
duration=event.duration,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueIterationCompletedEvent(
|
QueueIterationCompletedEvent(
|
||||||
node_execution_id=event.iteration_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.iteration_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.iteration_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.iteration_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.start_at,
|
start_at=event.start_at,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
inputs=event.inputs,
|
inputs=event.inputs,
|
||||||
outputs=event.outputs,
|
outputs=event.outputs,
|
||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, LoopRunStartedEvent):
|
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueLoopStartEvent(
|
QueueLoopStartEvent(
|
||||||
node_execution_id=event.loop_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.loop_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.loop_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.loop_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.start_at,
|
start_at=event.start_at,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
inputs=event.inputs,
|
inputs=event.inputs,
|
||||||
|
|
@ -655,42 +549,32 @@ class WorkflowBasedAppRunner:
|
||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, LoopRunNextEvent):
|
elif isinstance(event, NodeRunLoopNextEvent):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueLoopNextEvent(
|
QueueLoopNextEvent(
|
||||||
node_execution_id=event.loop_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.loop_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.loop_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.loop_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
index=event.index,
|
index=event.index,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
output=event.pre_loop_output,
|
output=event.pre_loop_output,
|
||||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
|
||||||
duration=event.duration,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueLoopCompletedEvent(
|
QueueLoopCompletedEvent(
|
||||||
node_execution_id=event.loop_id,
|
node_execution_id=event.id,
|
||||||
node_id=event.loop_node_id,
|
node_id=event.node_id,
|
||||||
node_type=event.loop_node_type,
|
node_type=event.node_type,
|
||||||
node_data=event.loop_node_data,
|
node_title=event.node_title,
|
||||||
parallel_id=event.parallel_id,
|
|
||||||
parallel_start_node_id=event.parallel_start_node_id,
|
|
||||||
parent_parallel_id=event.parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
|
||||||
start_at=event.start_at,
|
start_at=event.start_at,
|
||||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
inputs=event.inputs,
|
inputs=event.inputs,
|
||||||
outputs=event.outputs,
|
outputs=event.outputs,
|
||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
steps=event.steps,
|
steps=event.steps,
|
||||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
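As a reading aid for the handler rewritten above (not code from the commit), the renamed graph events and the queue events they are translated into can be summarized as a simple mapping; parallel-branch events have no counterpart any more because they were removed.

# Summary of the event translation performed by the isinstance chain above.
GRAPH_TO_QUEUE_EVENT = {
    "GraphRunAbortedEvent": "QueueWorkflowFailedEvent",
    "NodeRunAgentLogEvent": "QueueAgentLogEvent",
    "NodeRunIterationStartedEvent": "QueueIterationStartEvent",
    "NodeRunIterationNextEvent": "QueueIterationNextEvent",
    "NodeRunIterationSucceededEvent": "QueueIterationCompletedEvent",
    "NodeRunIterationFailedEvent": "QueueIterationCompletedEvent",
    "NodeRunLoopStartedEvent": "QueueLoopStartEvent",
    "NodeRunLoopNextEvent": "QueueLoopNextEvent",
    "NodeRunLoopSucceededEvent": "QueueLoopCompletedEvent",
    "NodeRunLoopFailedEvent": "QueueLoopCompletedEvent",
}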
|
|
@ -1,9 +1,12 @@
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
|
||||||
from constants import UUID_NIL
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
|
|
@ -35,6 +38,7 @@ class InvokeFrom(StrEnum):
|
||||||
# DEBUGGER indicates that this invocation is from
|
# DEBUGGER indicates that this invocation is from
|
||||||
# the workflow (or chatflow) edit page.
|
# the workflow (or chatflow) edit page.
|
||||||
DEBUGGER = "debugger"
|
DEBUGGER = "debugger"
|
||||||
|
PUBLISHED = "published"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str):
|
def value_of(cls, value: str):
|
||||||
|
|
@ -113,8 +117,7 @@ class AppGenerateEntity(BaseModel):
|
||||||
extras: dict[str, Any] = Field(default_factory=dict)
|
extras: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# tracing instance
|
# tracing instance
|
||||||
# Using Any to avoid circular import with TraceQueueManager
|
trace_manager: Optional["TraceQueueManager"] = None
|
||||||
trace_manager: Any | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||||
|
|
@ -240,3 +243,34 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||||
inputs: dict
|
inputs: dict
|
||||||
|
|
||||||
single_loop_run: SingleLoopRunEntity | None = None
|
single_loop_run: SingleLoopRunEntity | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||||
|
"""
|
||||||
|
RAG Pipeline Application Generate Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pipeline config
|
||||||
|
pipeline_config: WorkflowUIBasedAppConfig
|
||||||
|
datasource_type: str
|
||||||
|
datasource_info: Mapping[str, Any]
|
||||||
|
dataset_id: str
|
||||||
|
batch: str
|
||||||
|
document_id: str | None = None
|
||||||
|
original_document_id: str | None = None
|
||||||
|
start_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Import TraceQueueManager at runtime to resolve forward references
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
|
||||||
|
# Rebuild models that use forward references
|
||||||
|
AppGenerateEntity.model_rebuild()
|
||||||
|
EasyUIBasedAppGenerateEntity.model_rebuild()
|
||||||
|
ConversationAppGenerateEntity.model_rebuild()
|
||||||
|
ChatAppGenerateEntity.model_rebuild()
|
||||||
|
CompletionAppGenerateEntity.model_rebuild()
|
||||||
|
AgentChatAppGenerateEntity.model_rebuild()
|
||||||
|
AdvancedChatAppGenerateEntity.model_rebuild()
|
||||||
|
WorkflowAppGenerateEntity.model_rebuild()
|
||||||
|
RagPipelineGenerateEntity.model_rebuild()
|
||||||
|
|
|
||||||
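The trace_manager annotation changes from Any to a real forward reference, using the standard Pydantic v2 recipe for breaking a circular import: quote the type under TYPE_CHECKING, then import it for real and call model_rebuild() at the bottom of the module. A minimal, generic sketch of that pattern; the module and class names here are illustrative, not taken from the commit.

from typing import TYPE_CHECKING, Optional

from pydantic import BaseModel

if TYPE_CHECKING:
    # Seen only by type checkers, so no circular import at runtime.
    from some_module import HeavyManager  # illustrative name


class Entity(BaseModel):
    # Forward reference kept as a string until the model is rebuilt.
    manager: Optional["HeavyManager"] = None


# At module bottom, import for real and resolve the forward reference.
from some_module import HeavyManager  # noqa: E402

Entity.model_rebuild()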
|
|
@ -3,15 +3,13 @@ from datetime import datetime
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.base import BaseNodeData
|
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(StrEnum):
|
class QueueEvent(StrEnum):
|
||||||
|
|
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
|
||||||
ANNOTATION_REPLY = "annotation_reply"
|
ANNOTATION_REPLY = "annotation_reply"
|
||||||
AGENT_THOUGHT = "agent_thought"
|
AGENT_THOUGHT = "agent_thought"
|
||||||
MESSAGE_FILE = "message_file"
|
MESSAGE_FILE = "message_file"
|
||||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
|
||||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
|
||||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
|
||||||
AGENT_LOG = "agent_log"
|
AGENT_LOG = "agent_log"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
PING = "ping"
|
PING = "ping"
|
||||||
|
|
@ -80,21 +75,13 @@ class QueueIterationStartEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueIterationNextEvent(AppQueueEvent):
|
class QueueIterationNextEvent(AppQueueEvent):
|
||||||
|
|
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
parallel_mode_run_id: str | None = None
|
|
||||||
"""iteration run in parallel mode run id"""
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
output: Any | None = None # output for the current iteration
|
output: Any = None # output for the current iteration
|
||||||
duration: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||||
|
|
@ -134,21 +110,13 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
@ -175,9 +143,9 @@ class QueueLoopStartEvent(AppQueueEvent):
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueLoopNextEvent(AppQueueEvent):
|
class QueueLoopNextEvent(AppQueueEvent):
|
||||||
|
|
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
@ -203,8 +171,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
||||||
parallel_mode_run_id: str | None = None
|
parallel_mode_run_id: str | None = None
|
||||||
"""iteration run in parallel mode run id"""
|
"""iteration run in parallel mode run id"""
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
output: Any | None = None # output for the current loop
|
output: Any = None # output for the current loop
|
||||||
duration: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||||
|
|
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_title: str
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
@ -229,9 +196,9 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
steps: int = 0
|
steps: int = 0
|
||||||
|
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
@ -332,7 +299,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||||
outputs: dict[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||||
|
|
@ -352,7 +319,7 @@ class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||||
exceptions_count: int
|
exceptions_count: int
|
||||||
outputs: dict[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeStartedEvent(AppQueueEvent):
|
class QueueNodeStartedEvent(AppQueueEvent):
|
||||||
|
|
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||||
|
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
|
node_title: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||||
node_run_index: int = 1
|
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
parent_parallel_id: str | None = None
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
parent_parallel_start_node_id: str | None = None
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
in_iteration_id: str | None = None
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
in_loop_id: str | None = None
|
||||||
"""loop id if node is in loop"""
|
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
parallel_mode_run_id: str | None = None
|
parallel_mode_run_id: str | None = None
|
||||||
"""iteration run in parallel mode run id"""
|
|
||||||
agent_strategy: AgentNodeStrategyInit | None = None
|
agent_strategy: AgentNodeStrategyInit | None = None
|
||||||
|
|
||||||
|
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||||
|
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||||
|
provider_id: str
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
|
|
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
@ -411,16 +374,12 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
"""single iteration duration map"""
|
|
||||||
iteration_duration_map: dict[str, float] | None = None
|
|
||||||
"""single loop duration map"""
|
|
||||||
loop_duration_map: dict[str, float] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueueAgentLogEvent(AppQueueEvent):
|
class QueueAgentLogEvent(AppQueueEvent):
|
||||||
|
|
@ -436,7 +395,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -445,81 +404,15 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.RETRY
|
event: QueueEvent = QueueEvent.RETRY
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
retry_index: int # retry index
|
retry_index: int # retry index
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueNodeInIterationFailedEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
|
||||||
|
|
||||||
node_execution_id: str
|
|
||||||
node_id: str
|
|
||||||
node_type: NodeType
|
|
||||||
node_data: BaseNodeData
|
|
||||||
parallel_id: str | None = None
|
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
|
||||||
"""loop id if node is in loop"""
|
|
||||||
start_at: datetime
|
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
|
||||||
process_data: Mapping[str, Any] | None = None
|
|
||||||
outputs: Mapping[str, Any] | None = None
|
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
|
||||||
|
|
||||||
error: str
|
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueNodeInLoopFailedEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
|
||||||
|
|
||||||
node_execution_id: str
|
|
||||||
node_id: str
|
|
||||||
node_type: NodeType
|
|
||||||
node_data: BaseNodeData
|
|
||||||
parallel_id: str | None = None
|
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
|
||||||
"""loop id if node is in loop"""
|
|
||||||
start_at: datetime
|
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
|
||||||
process_data: Mapping[str, Any] | None = None
|
|
||||||
outputs: Mapping[str, Any] | None = None
|
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
|
||||||
|
|
||||||
error: str
|
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
QueueNodeExceptionEvent entity
|
QueueNodeExceptionEvent entity
|
||||||
|
|
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
@ -545,9 +437,9 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
|
|
@ -563,24 +455,16 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||||
node_execution_id: str
|
node_execution_id: str
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
"""parallel id if node is in parallel"""
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
"""parallel start node id if node is in parallel"""
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
in_iteration_id: str | None = None
|
||||||
"""iteration id if node is in iteration"""
|
"""iteration id if node is in iteration"""
|
||||||
in_loop_id: str | None = None
|
in_loop_id: str | None = None
|
||||||
"""loop id if node is in loop"""
|
"""loop id if node is in loop"""
|
||||||
start_at: datetime
|
start_at: datetime
|
||||||
|
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, object] = Field(default_factory=dict)
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
|
|
@ -610,7 +494,7 @@ class QueueErrorEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.ERROR
|
event: QueueEvent = QueueEvent.ERROR
|
||||||
error: Any | None = None
|
error: Any = None
|
||||||
|
|
||||||
|
|
||||||
class QueuePingEvent(AppQueueEvent):
|
class QueuePingEvent(AppQueueEvent):
|
||||||
|
|
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueParallelBranchRunStartedEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
|
||||||
|
|
||||||
parallel_id: str
|
|
||||||
parallel_start_node_id: str
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
|
||||||
"""loop id if node is in loop"""
|
|
||||||
|
|
||||||
|
|
||||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueParallelBranchRunSucceededEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
|
||||||
|
|
||||||
parallel_id: str
|
|
||||||
parallel_start_node_id: str
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
|
||||||
"""loop id if node is in loop"""
|
|
||||||
|
|
||||||
|
|
||||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueParallelBranchRunFailedEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
|
||||||
|
|
||||||
parallel_id: str
|
|
||||||
parallel_start_node_id: str
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
"""parent parallel id if node is in parallel"""
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
"""parent parallel start node id if node is in parallel"""
|
|
||||||
in_iteration_id: str | None = None
|
|
||||||
"""iteration id if node is in iteration"""
|
|
||||||
in_loop_id: str | None = None
|
|
||||||
"""loop id if node is in loop"""
|
|
||||||
error: str
|
|
||||||
|
|
|
||||||
|
|
@@ -0,0 +1,14 @@
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class RagPipelineInvokeEntity(BaseModel):
+    pipeline_id: str
+    application_generate_entity: dict[str, Any]
+    user_id: str
+    tenant_id: str
+    workflow_id: str
+    streaming: bool
+    workflow_execution_id: str | None = None
+    workflow_thread_pool_id: str | None = None
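A minimal usage sketch of the new RagPipelineInvokeEntity (illustrative only; the id values and the application_generate_entity payload below are placeholders, not values from this commit):

# Package the arguments of a pipeline invocation so they can be queued or
# handed to a worker as plain JSON (assumes the class above is importable).
entity = RagPipelineInvokeEntity(
    pipeline_id="pipeline-123",
    application_generate_entity={"inputs": {}},
    user_id="user-1",
    tenant_id="tenant-1",
    workflow_id="workflow-1",
    streaming=True,
)
payload = entity.model_dump()  # plain dict, safe to serialize or enqueue
restored = RagPipelineInvokeEntity.model_validate(payload)
assert restored.workflow_execution_id is None  # optional fields default to None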
|
|
@@ -1,13 +1,13 @@
 from collections.abc import Mapping, Sequence
-from enum import StrEnum, auto
+from enum import StrEnum
 from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


 class AnnotationReplyAccount(BaseModel):
@@ -55,32 +55,30 @@ class StreamEvent(StrEnum):
     Stream event
     """

-    PING = auto()
-    ERROR = auto()
-    MESSAGE = auto()
-    MESSAGE_END = auto()
-    TTS_MESSAGE = auto()
-    TTS_MESSAGE_END = auto()
-    MESSAGE_FILE = auto()
-    MESSAGE_REPLACE = auto()
-    AGENT_THOUGHT = auto()
-    AGENT_MESSAGE = auto()
-    WORKFLOW_STARTED = auto()
-    WORKFLOW_FINISHED = auto()
-    NODE_STARTED = auto()
-    NODE_FINISHED = auto()
-    NODE_RETRY = auto()
-    PARALLEL_BRANCH_STARTED = auto()
-    PARALLEL_BRANCH_FINISHED = auto()
-    ITERATION_STARTED = auto()
-    ITERATION_NEXT = auto()
-    ITERATION_COMPLETED = auto()
-    LOOP_STARTED = auto()
-    LOOP_NEXT = auto()
-    LOOP_COMPLETED = auto()
-    TEXT_CHUNK = auto()
-    TEXT_REPLACE = auto()
-    AGENT_LOG = auto()
+    PING = "ping"
+    ERROR = "error"
+    MESSAGE = "message"
+    MESSAGE_END = "message_end"
+    TTS_MESSAGE = "tts_message"
+    TTS_MESSAGE_END = "tts_message_end"
+    MESSAGE_FILE = "message_file"
+    MESSAGE_REPLACE = "message_replace"
+    AGENT_THOUGHT = "agent_thought"
+    AGENT_MESSAGE = "agent_message"
+    WORKFLOW_STARTED = "workflow_started"
+    WORKFLOW_FINISHED = "workflow_finished"
+    NODE_STARTED = "node_started"
+    NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
+    ITERATION_STARTED = "iteration_started"
+    ITERATION_NEXT = "iteration_next"
+    ITERATION_COMPLETED = "iteration_completed"
+    LOOP_STARTED = "loop_started"
+    LOOP_NEXT = "loop_next"
+    LOOP_COMPLETED = "loop_completed"
+    TEXT_CHUNK = "text_chunk"
+    TEXT_REPLACE = "text_replace"
+    AGENT_LOG = "agent_log"


 class StreamResponse(BaseModel):
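The enum values above double as the public SSE event names, so spelling them out (instead of deriving them via auto()) pins the wire format explicitly. A small sketch of client-side dispatch keyed on those values (handler bodies are illustrative, not part of the commit):

assert StreamEvent.NODE_STARTED == "node_started"  # StrEnum members compare equal to their value

handlers = {
    StreamEvent.WORKFLOW_STARTED: lambda data: print("run started", data),
    StreamEvent.WORKFLOW_FINISHED: lambda data: print("run finished", data),
}

def dispatch(event_name: str, data: dict) -> None:
    # StreamEvent(event_name) raises ValueError for unknown event names.
    handler = handlers.get(StreamEvent(event_name))
    if handler is not None:
        handler(data)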
|
|
@ -138,7 +136,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||||
id: str
|
id: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
files: Sequence[Mapping[str, Any]] | None = None
|
files: Sequence[Mapping[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -175,7 +173,7 @@ class AgentThoughtStreamResponse(StreamResponse):
|
||||||
thought: str | None = None
|
thought: str | None = None
|
||||||
observation: str | None = None
|
observation: str | None = None
|
||||||
tool: str | None = None
|
tool: str | None = None
|
||||||
tool_labels: dict | None = None
|
tool_labels: Mapping[str, object] = Field(default_factory=dict)
|
||||||
tool_input: str | None = None
|
tool_input: str | None = None
|
||||||
message_files: list[str] | None = None
|
message_files: list[str] | None = None
|
||||||
|
|
||||||
|
|
@ -228,7 +226,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
total_steps: int
|
total_steps: int
|
||||||
created_by: dict | None = None
|
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int
|
finished_at: int
|
||||||
exceptions_count: int | None = 0
|
exceptions_count: int | None = 0
|
||||||
|
|
@ -256,8 +254,9 @@ class NodeStartStreamResponse(StreamResponse):
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, Any] | None = None
|
||||||
|
inputs_truncated: bool = False
|
||||||
created_at: int
|
created_at: int
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: dict[str, object] = Field(default_factory=dict)
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
parent_parallel_id: str | None = None
|
parent_parallel_id: str | None = None
|
||||||
|
|
@ -313,8 +312,11 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, Any] | None = None
|
||||||
|
inputs_truncated: bool = False
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, Any] | None = None
|
||||||
|
process_data_truncated: bool = False
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, Any] | None = None
|
||||||
|
outputs_truncated: bool = True
|
||||||
status: str
|
status: str
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
|
|
@ -382,8 +384,11 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||||
index: int
|
index: int
|
||||||
predecessor_node_id: str | None = None
|
predecessor_node_id: str | None = None
|
||||||
inputs: Mapping[str, Any] | None = None
|
inputs: Mapping[str, Any] | None = None
|
||||||
|
inputs_truncated: bool = False
|
||||||
process_data: Mapping[str, Any] | None = None
|
process_data: Mapping[str, Any] | None = None
|
||||||
|
process_data_truncated: bool = False
|
||||||
outputs: Mapping[str, Any] | None = None
|
outputs: Mapping[str, Any] | None = None
|
||||||
|
outputs_truncated: bool = False
|
||||||
status: str
|
status: str
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
|
|
@ -436,54 +441,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
|
||||||
"""
|
|
||||||
ParallelBranchStartStreamResponse entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
parallel_id: str
|
|
||||||
parallel_branch_id: str
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
iteration_id: str | None = None
|
|
||||||
loop_id: str | None = None
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
|
||||||
"""
|
|
||||||
ParallelBranchFinishedStreamResponse entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
parallel_id: str
|
|
||||||
parallel_branch_id: str
|
|
||||||
parent_parallel_id: str | None = None
|
|
||||||
parent_parallel_start_node_id: str | None = None
|
|
||||||
iteration_id: str | None = None
|
|
||||||
loop_id: str | None = None
|
|
||||||
status: str
|
|
||||||
error: str | None = None
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class IterationNodeStartStreamResponse(StreamResponse):
|
class IterationNodeStartStreamResponse(StreamResponse):
|
||||||
"""
|
"""
|
||||||
NodeStartStreamResponse entity
|
NodeStartStreamResponse entity
|
||||||
|
|
@ -502,8 +459,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: dict = Field(default_factory=dict)
|
||||||
metadata: Mapping = {}
|
metadata: Mapping = {}
|
||||||
inputs: Mapping = {}
|
inputs: Mapping = {}
|
||||||
parallel_id: str | None = None
|
inputs_truncated: bool = False
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
|
|
@ -526,12 +482,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
created_at: int
|
created_at: int
|
||||||
pre_iteration_output: Any | None = None
|
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: dict = Field(default_factory=dict)
|
||||||
parallel_id: str | None = None
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
parallel_mode_run_id: str | None = None
|
|
||||||
duration: float | None = None
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
|
|
@ -553,18 +504,18 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||||
node_type: str
|
node_type: str
|
||||||
title: str
|
title: str
|
||||||
outputs: Mapping | None = None
|
outputs: Mapping | None = None
|
||||||
|
outputs_truncated: bool = False
|
||||||
created_at: int
|
created_at: int
|
||||||
extras: dict | None = None
|
extras: dict | None = None
|
||||||
inputs: Mapping | None = None
|
inputs: Mapping | None = None
|
||||||
|
inputs_truncated: bool = False
|
||||||
status: WorkflowNodeExecutionStatus
|
status: WorkflowNodeExecutionStatus
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
execution_metadata: Mapping | None = None
|
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
finished_at: int
|
finished_at: int
|
||||||
steps: int
|
steps: int
|
||||||
parallel_id: str | None = None
|
|
||||||
parallel_start_node_id: str | None = None
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
|
|
@ -589,6 +540,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: dict = Field(default_factory=dict)
|
||||||
metadata: Mapping = {}
|
metadata: Mapping = {}
|
||||||
inputs: Mapping = {}
|
inputs: Mapping = {}
|
||||||
|
inputs_truncated: bool = False
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
|
|
||||||
|
|
@ -613,12 +565,11 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
||||||
title: str
|
title: str
|
||||||
index: int
|
index: int
|
||||||
created_at: int
|
created_at: int
|
||||||
pre_loop_output: Any | None = None
|
pre_loop_output: Any = None
|
||||||
extras: dict = Field(default_factory=dict)
|
extras: Mapping[str, object] = Field(default_factory=dict)
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
parallel_start_node_id: str | None = None
|
parallel_start_node_id: str | None = None
|
||||||
parallel_mode_run_id: str | None = None
|
parallel_mode_run_id: str | None = None
|
||||||
duration: float | None = None
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
|
|
@ -640,14 +591,16 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||||
node_type: str
|
node_type: str
|
||||||
title: str
|
title: str
|
||||||
outputs: Mapping | None = None
|
outputs: Mapping | None = None
|
||||||
|
outputs_truncated: bool = False
|
||||||
created_at: int
|
created_at: int
|
||||||
extras: dict | None = None
|
extras: dict | None = None
|
||||||
inputs: Mapping | None = None
|
inputs: Mapping | None = None
|
||||||
|
inputs_truncated: bool = False
|
||||||
status: WorkflowNodeExecutionStatus
|
status: WorkflowNodeExecutionStatus
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
elapsed_time: float
|
elapsed_time: float
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
execution_metadata: Mapping | None = None
|
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
finished_at: int
|
finished_at: int
|
||||||
steps: int
|
steps: int
|
||||||
parallel_id: str | None = None
|
parallel_id: str | None = None
|
||||||
|
|
@ -757,7 +710,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
message_id: str
|
message_id: str
|
||||||
answer: str
|
answer: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
data: Data
|
data: Data
|
||||||
|
|
@ -777,7 +730,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
|
||||||
mode: str
|
mode: str
|
||||||
message_id: str
|
message_id: str
|
||||||
answer: str
|
answer: str
|
||||||
metadata: dict = Field(default_factory=dict)
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
|
|
||||||
data: Data
|
data: Data
|
||||||
|
|
@ -825,7 +778,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
status: str
|
status: str
|
||||||
data: Mapping[str, Any]
|
data: Mapping[str, Any]
|
||||||
metadata: Mapping[str, Any] | None = None
|
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||||
node_id: str
|
node_id: str
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||||
|
|
|
||||||
|
|
@@ -138,6 +138,8 @@ class MessageCycleManager:
         :param event: event
         :return:
         """
+        if not self._application_generate_entity.app_config.additional_features:
+            raise ValueError("Additional features not found")
         if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
             self._task_state.metadata.retriever_resources = event.retriever_resources

|
|
|
||||||
|
|
@@ -109,7 +109,9 @@ class AppGeneratorTTSPublisher:
             elif isinstance(message.event, QueueNodeSucceededEvent):
                 if message.event.outputs is None:
                     continue
-                self.msg_text += message.event.outputs.get("output", "")
+                output = message.event.outputs.get("output", "")
+                if isinstance(output, str):
+                    self.msg_text += output
                 self.last_message = message
                 sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                 if len(sentence_arr) >= min(self.max_sentence, 7):
@@ -119,7 +121,7 @@ class AppGeneratorTTSPublisher:
                         _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
                     )
                     future_queue.put(futures_result)
-                    if text_tmp:
+                    if isinstance(text_tmp, str):
                         self.msg_text = text_tmp
                     else:
                         self.msg_text = ""
|
|
|
||||||
|
|
@@ -105,6 +105,14 @@ class DifyAgentCallbackHandler(BaseModel):

         self.current_loop += 1

+    def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
+        """Run on datasource start."""
+        if dify_config.DEBUG:
+            print_text(
+                "\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
+                color=self.color,
+            )
+
     @property
     def ignore_agent(self) -> bool:
         """Whether to ignore agent callbacks."""
|
|
|
||||||
|
|
@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+
+from configs import dify_config
+from core.datasource.__base.datasource_runtime import DatasourceRuntime
+from core.datasource.entities.datasource_entities import (
+    DatasourceEntity,
+    DatasourceProviderType,
+)
+
+
+class DatasourcePlugin(ABC):
+    entity: DatasourceEntity
+    runtime: DatasourceRuntime
+    icon: str
+
+    def __init__(
+        self,
+        entity: DatasourceEntity,
+        runtime: DatasourceRuntime,
+        icon: str,
+    ) -> None:
+        self.entity = entity
+        self.runtime = runtime
+        self.icon = icon
+
+    @abstractmethod
+    def datasource_provider_type(self) -> str:
+        """
+        returns the type of the datasource provider
+        """
+        return DatasourceProviderType.LOCAL_FILE
+
+    def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
+        return self.__class__(
+            entity=self.entity.model_copy(),
+            runtime=runtime,
+            icon=self.icon,
+        )
+
+    def get_icon_url(self, tenant_id: str) -> str:
+        return f"{dify_config.CONSOLE_API_URL}/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={self.icon}"  # noqa: E501
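A sketch of how a concrete plugin could extend the DatasourcePlugin base class above (the subclass name is invented; real implementations live in their own provider packages and receive entity/runtime/icon from the plugin manager):

class ExampleLocalFileDatasourcePlugin(DatasourcePlugin):
    # Overrides the abstract hook with the provider type this plugin serves.
    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE

# fork_datasource_runtime keeps the declaration (entity, icon) but swaps in
# per-invocation state, e.g.:
#     scoped = plugin.fork_datasource_runtime(runtime=DatasourceRuntime(tenant_id="t-1"))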
|
|
@ -0,0 +1,118 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourcePluginProviderController(ABC):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
|
||||||
|
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
|
self.entity = entity
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_credentials(self) -> bool:
|
||||||
|
"""
|
||||||
|
returns whether the provider needs credentials
|
||||||
|
|
||||||
|
:return: whether the provider needs credentials
|
||||||
|
"""
|
||||||
|
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||||
|
|
||||||
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
manager = PluginToolManager()
|
||||||
|
if not manager.validate_datasource_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=self.entity.identity.name,
|
||||||
|
credentials=credentials,
|
||||||
|
):
|
||||||
|
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the format of the credentials of the provider and set the default value if needed
|
||||||
|
|
||||||
|
:param credentials: the credentials of the tool
|
||||||
|
"""
|
||||||
|
credentials_schema = dict[str, ProviderConfig]()
|
||||||
|
if credentials_schema is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for credential in self.entity.credentials_schema:
|
||||||
|
credentials_schema[credential.name] = credential
|
||||||
|
|
||||||
|
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||||
|
for credential_name in credentials_schema:
|
||||||
|
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||||
|
|
||||||
|
for credential_name in credentials:
|
||||||
|
if credential_name not in credentials_need_to_validate:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check type
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if not credential_schema.required and credentials[credential_name] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
options = credential_schema.options
|
||||||
|
if not isinstance(options, list):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||||
|
|
||||||
|
if credentials[credential_name] not in [x.value for x in options]:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} should be one of {options}"
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_need_to_validate.pop(credential_name)
|
||||||
|
|
||||||
|
for credential_name in credentials_need_to_validate:
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if credential_schema.required:
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||||
|
|
||||||
|
# the credential is not set currently, set the default value if needed
|
||||||
|
if credential_schema.default is not None:
|
||||||
|
default_value = credential_schema.default
|
||||||
|
# parse default value into the correct type
|
||||||
|
if credential_schema.type in {
|
||||||
|
ProviderConfig.Type.SECRET_INPUT,
|
||||||
|
ProviderConfig.Type.TEXT_INPUT,
|
||||||
|
ProviderConfig.Type.SELECT,
|
||||||
|
}:
|
||||||
|
default_value = str(default_value)
|
||||||
|
|
||||||
|
credentials[credential_name] = default_value
|
||||||
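To summarize the control flow of DatasourcePluginProviderController.validate_credentials_format above, a hedged sketch (controller stands for any concrete provider controller instance; the credential keys are invented):

# The method rejects keys not declared in credentials_schema, type-checks
# secret/text/select values (including select options), back-fills declared
# defaults, and finally raises if a required credential is still missing.
credentials = {"api_key": "sk-demo"}  # hypothetical console input
controller.validate_credentials_format(credentials)
# On return, credentials may have gained defaults, e.g. {"api_key": "sk-demo", "region": "us"}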
|
|
@@ -0,0 +1,40 @@
+from typing import TYPE_CHECKING, Any, Optional
+
+from openai import BaseModel
+from pydantic import Field
+
+# Import InvokeFrom locally to avoid circular import
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
+
+if TYPE_CHECKING:
+    from core.app.entities.app_invoke_entities import InvokeFrom
+
+
+class DatasourceRuntime(BaseModel):
+    """
+    Meta data of a datasource call processing
+    """
+
+    tenant_id: str
+    datasource_id: str | None = None
+    invoke_from: Optional["InvokeFrom"] = None
+    datasource_invoke_from: DatasourceInvokeFrom | None = None
+    credentials: dict[str, Any] = Field(default_factory=dict)
+    runtime_parameters: dict[str, Any] = Field(default_factory=dict)
+
+
+class FakeDatasourceRuntime(DatasourceRuntime):
+    """
+    Fake datasource runtime for testing
+    """
+
+    def __init__(self):
+        super().__init__(
+            tenant_id="fake_tenant_id",
+            datasource_id="fake_datasource_id",
+            invoke_from=InvokeFrom.DEBUGGER,
+            datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
+            credentials={},
+            runtime_parameters={},
+        )
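A small sketch of using FakeDatasourceRuntime in a test (the runtime parameter name is illustrative):

def make_test_runtime() -> DatasourceRuntime:
    runtime = FakeDatasourceRuntime()
    runtime.runtime_parameters["limit"] = 10  # hypothetical per-call parameter
    return runtime

runtime = make_test_runtime()
assert runtime.tenant_id == "fake_tenant_id"
assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE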
|
|
@ -0,0 +1,218 @@
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from mimetypes import guess_extension, guess_type
|
||||||
|
from typing import Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
from models.model import MessageFile, UploadFile
|
||||||
|
from models.tools import ToolFile
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceFileManager:
|
||||||
|
@staticmethod
|
||||||
|
def sign_file(datasource_file_id: str, extension: str) -> str:
|
||||||
|
"""
|
||||||
|
sign file to get a temporary url
|
||||||
|
"""
|
||||||
|
base_url = dify_config.FILES_URL
|
||||||
|
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
"""
|
||||||
|
verify signature
|
||||||
|
"""
|
||||||
|
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_file_by_raw(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: str | None,
|
||||||
|
file_binary: bytes,
|
||||||
|
mimetype: str,
|
||||||
|
filename: str | None = None,
|
||||||
|
) -> UploadFile:
|
||||||
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
|
unique_name = uuid4().hex
|
||||||
|
unique_filename = f"{unique_name}{extension}"
|
||||||
|
# default just as before
|
||||||
|
present_filename = unique_filename
|
||||||
|
if filename is not None:
|
||||||
|
has_extension = len(filename.split(".")) > 1
|
||||||
|
# Add extension flexibly
|
||||||
|
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||||
|
filepath = f"datasources/{tenant_id}/{unique_filename}"
|
||||||
|
storage.save(filepath, file_binary)
|
||||||
|
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
storage_type=dify_config.STORAGE_TYPE,
|
||||||
|
key=filepath,
|
||||||
|
name=present_filename,
|
||||||
|
size=len(file_binary),
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mimetype,
|
||||||
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
|
created_by=user_id,
|
||||||
|
used=False,
|
||||||
|
hash=hashlib.sha3_256(file_binary).hexdigest(),
|
||||||
|
source_url="",
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(upload_file)
|
||||||
|
|
||||||
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_file_by_url(
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
file_url: str,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
) -> ToolFile:
|
||||||
|
# try to download image
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.get(file_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
blob = response.content
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||||
|
|
||||||
|
mimetype = (
|
||||||
|
guess_type(file_url)[0]
|
||||||
|
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
|
unique_name = uuid4().hex
|
||||||
|
filename = f"{unique_name}{extension}"
|
||||||
|
filepath = f"tools/{tenant_id}/{filename}"
|
||||||
|
storage.save(filepath, blob)
|
||||||
|
|
||||||
|
tool_file = ToolFile(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_key=filepath,
|
||||||
|
mimetype=mimetype,
|
||||||
|
original_url=file_url,
|
||||||
|
name=filename,
|
||||||
|
size=len(blob),
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(tool_file)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return tool_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param id: the id of the file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
|
||||||
|
return blob, upload_file.mime_type
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param id: the id of the file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
|
||||||
|
|
||||||
|
# Check if message_file is not None
|
||||||
|
if message_file is not None:
|
||||||
|
# get tool file id
|
||||||
|
if message_file.url is not None:
|
||||||
|
tool_file_id = message_file.url.split("/")[-1]
|
||||||
|
# trim extension
|
||||||
|
tool_file_id = tool_file_id.split(".")[0]
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
|
|
||||||
|
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||||
|
|
||||||
|
if not tool_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
blob = storage.load_once(tool_file.file_key)
|
||||||
|
|
||||||
|
return blob, tool_file.mimetype
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_generator_by_upload_file_id(upload_file_id: str):
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param tool_file_id: the id of the tool file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
stream = storage.load_stream(upload_file.key)
|
||||||
|
|
||||||
|
return stream, upload_file.mime_type
|
||||||
|
|
||||||
|
|
||||||
|
# init tool_file_parser
|
||||||
|
# from core.file.datasource_file_parser import datasource_file_manager
|
||||||
|
#
|
||||||
|
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||||
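The preview-URL scheme in DatasourceFileManager above is an HMAC-SHA256 over "file-preview|<id>|<timestamp>|<nonce>" keyed with SECRET_KEY. A round-trip sketch, assuming dify_config has FILES_URL and SECRET_KEY set (the file id and extension are placeholders):

from urllib.parse import parse_qs, urlparse

# sign_file appends timestamp, nonce and signature as query parameters.
url = DatasourceFileManager.sign_file(datasource_file_id="df-123", extension=".pdf")
params = {key: values[0] for key, values in parse_qs(urlparse(url).query).items()}

# verify_file recomputes the HMAC and also enforces FILES_ACCESS_TIMEOUT.
assert DatasourceFileManager.verify_file(
    datasource_file_id="df-123",
    timestamp=params["timestamp"],
    nonce=params["nonce"],
    sign=params["sign"],
)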
|
|
@ -0,0 +1,112 @@
|
||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.entities.common_entities import I18nObject
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||||
|
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||||
|
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||||
|
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||||
|
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||||
|
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceManager:
|
||||||
|
_builtin_provider_lock = Lock()
|
||||||
|
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||||
|
_builtin_providers_loaded = False
|
||||||
|
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_datasource_plugin_provider(
|
||||||
|
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||||
|
) -> DatasourcePluginProviderController:
|
||||||
|
"""
|
||||||
|
get the datasource plugin provider
|
||||||
|
"""
|
||||||
|
# check if context is set
|
||||||
|
try:
|
||||||
|
contexts.datasource_plugin_providers.get()
|
||||||
|
except LookupError:
|
||||||
|
contexts.datasource_plugin_providers.set({})
|
||||||
|
contexts.datasource_plugin_providers_lock.set(Lock())
|
||||||
|
|
||||||
|
with contexts.datasource_plugin_providers_lock.get():
|
||||||
|
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||||
|
if provider_id in datasource_plugin_providers:
|
||||||
|
return datasource_plugin_providers[provider_id]
|
||||||
|
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||||
|
if not provider_entity:
|
||||||
|
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||||
|
controller: DatasourcePluginProviderController | None = None
|
||||||
|
match datasource_type:
|
||||||
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
|
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.ONLINE_DRIVE:
|
||||||
|
controller = OnlineDriveDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
|
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
|
controller = LocalFileDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||||
|
|
||||||
|
if controller:
|
||||||
|
datasource_plugin_providers[provider_id] = controller
|
||||||
|
|
||||||
|
if controller is None:
|
||||||
|
raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
|
||||||
|
|
||||||
|
return controller
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_datasource_runtime(
|
||||||
|
cls,
|
||||||
|
provider_id: str,
|
||||||
|
datasource_name: str,
|
||||||
|
tenant_id: str,
|
||||||
|
datasource_type: DatasourceProviderType,
|
||||||
|
) -> DatasourcePlugin:
|
||||||
|
"""
|
||||||
|
get the datasource runtime
|
||||||
|
|
||||||
|
:param provider_type: the type of the provider
|
||||||
|
:param provider_id: the id of the provider
|
||||||
|
:param datasource_name: the name of the datasource
|
||||||
|
:param tenant_id: the tenant id
|
||||||
|
|
||||||
|
:return: the datasource plugin
|
||||||
|
"""
|
||||||
|
return cls.get_datasource_plugin_provider(
|
||||||
|
provider_id,
|
||||||
|
tenant_id,
|
||||||
|
datasource_type,
|
||||||
|
).get_datasource(datasource_name)
|
||||||
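A usage sketch for the DatasourceManager helpers above (the provider id and datasource name are invented examples; real values come from the installed plugin):

# Resolve the provider controller (cached in the request context), then fetch
# the concrete datasource plugin by name.
datasource = DatasourceManager.get_datasource_runtime(
    provider_id="langgenius/notion_datasource/notion",  # hypothetical provider id
    datasource_name="notion_page",                      # hypothetical datasource name
    tenant_id="tenant-1",                               # placeholder tenant
    datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)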
|
|
@@ -0,0 +1,71 @@
from typing import Literal, Optional

from pydantic import BaseModel, Field, field_validator

from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject


class DatasourceApiEntity(BaseModel):
    author: str
    name: str  # identifier
    label: I18nObject  # label
    description: I18nObject
    parameters: list[DatasourceParameter] | None = None
    labels: list[str] = Field(default_factory=list)
    output_schema: dict | None = None


ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]


class DatasourceProviderApiEntity(BaseModel):
    id: str
    author: str
    name: str  # identifier
    description: I18nObject
    icon: str | dict
    label: I18nObject  # label
    type: str
    masked_credentials: dict | None = None
    original_credentials: dict | None = None
    is_team_authorization: bool = False
    allow_delete: bool = True
    plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
    plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
    datasources: list[DatasourceApiEntity] = Field(default_factory=list)
    labels: list[str] = Field(default_factory=list)

    @field_validator("datasources", mode="before")
    @classmethod
    def convert_none_to_empty_list(cls, v):
        return v if v is not None else []

    def to_dict(self) -> dict:
        # -------------
        # overwrite datasource parameter types for temp fix
        datasources = jsonable_encoder(self.datasources)
        for datasource in datasources:
            if datasource.get("parameters"):
                for parameter in datasource.get("parameters"):
                    if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
                        parameter["type"] = "files"
        # -------------

        return {
            "id": self.id,
            "author": self.author,
            "name": self.name,
            "plugin_id": self.plugin_id,
            "plugin_unique_identifier": self.plugin_unique_identifier,
            "description": self.description.to_dict(),
            "icon": self.icon,
            "label": self.label.to_dict(),
            "type": self.type,
            "team_credentials": self.masked_credentials,
            "is_team_authorization": self.is_team_authorization,
            "allow_delete": self.allow_delete,
            "datasources": datasources,
            "labels": self.labels,
        }
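
A short sketch of serializing a provider for the console API; the field values are made up for illustration. Any datasource parameter whose type serialized as the deprecated SYSTEM_FILES value is rewritten to "files" inside to_dict() before the payload is returned:

provider = DatasourceProviderApiEntity(
    id="provider-1",
    author="example",
    name="example_provider",
    description=I18nObject(en_US="Example provider"),
    icon="icon.svg",
    label=I18nObject(en_US="Example"),
    type="online_document",
)
payload = provider.to_dict()
# payload["team_credentials"] mirrors masked_credentials; datasources is the
# jsonable-encoded list with the SYSTEM_FILES -> "files" rewrite applied.
assert payload["name"] == "example_provider"
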
@@ -0,0 +1,21 @@
from pydantic import BaseModel, Field


class I18nObject(BaseModel):
    """
    Model class for i18n object.
    """

    en_US: str
    zh_Hans: str | None = Field(default=None)
    pt_BR: str | None = Field(default=None)
    ja_JP: str | None = Field(default=None)

    def __init__(self, **data):
        super().__init__(**data)
        self.zh_Hans = self.zh_Hans or self.en_US
        self.pt_BR = self.pt_BR or self.en_US
        self.ja_JP = self.ja_JP or self.en_US

    def to_dict(self) -> dict:
        return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
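
A small sketch of the fallback behaviour: locales that are not supplied inherit the en_US text after construction.

label = I18nObject(en_US="Knowledge Pipeline")
assert label.zh_Hans == "Knowledge Pipeline"  # falls back to en_US
assert label.to_dict() == {
    "zh_Hans": "Knowledge Pipeline",
    "en_US": "Knowledge Pipeline",
    "pt_BR": "Knowledge Pipeline",
    "ja_JP": "Knowledge Pipeline",
}
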
@@ -0,0 +1,380 @@
import enum
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL

from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
    PluginParameter,
    PluginParameterOption,
    PluginParameterType,
    as_normal_type,
    cast_parameter_value,
    init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum


class DatasourceProviderType(enum.StrEnum):
    """
    Enum class for datasource provider
    """

    ONLINE_DOCUMENT = "online_document"
    LOCAL_FILE = "local_file"
    WEBSITE_CRAWL = "website_crawl"
    ONLINE_DRIVE = "online_drive"

    @classmethod
    def value_of(cls, value: str) -> "DatasourceProviderType":
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")
class DatasourceParameter(PluginParameter):
    """
    Overrides type
    """

    class DatasourceParameterType(enum.StrEnum):
        """
        removes TOOLS_SELECTOR from PluginParameterType
        """

        STRING = PluginParameterType.STRING.value
        NUMBER = PluginParameterType.NUMBER.value
        BOOLEAN = PluginParameterType.BOOLEAN.value
        SELECT = PluginParameterType.SELECT.value
        SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
        FILE = PluginParameterType.FILE.value
        FILES = PluginParameterType.FILES.value

        # deprecated, should not be used
        SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value

        def as_normal_type(self):
            return as_normal_type(self)

        def cast_value(self, value: Any):
            return cast_parameter_value(self, value)

    type: DatasourceParameterType = Field(..., description="The type of the parameter")
    description: I18nObject = Field(..., description="The description of the parameter")

    @classmethod
    def get_simple_instance(
        cls,
        name: str,
        typ: DatasourceParameterType,
        required: bool,
        options: list[str] | None = None,
    ) -> "DatasourceParameter":
        """
        Get a simple datasource parameter.

        :param name: the name of the parameter
        :param typ: the type of the parameter
        :param required: if the parameter is required
        :param options: the options of the parameter
        """
        # convert options to PluginParameterOption
        # FIXME fix the type error
        if options:
            option_objs = [
                PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
                for option in options
            ]
        else:
            option_objs = []

        return cls(
            name=name,
            label=I18nObject(en_US="", zh_Hans=""),
            placeholder=None,
            type=typ,
            required=required,
            options=option_objs,
            description=I18nObject(en_US="", zh_Hans=""),
        )

    def init_frontend_parameter(self, value: Any):
        return init_frontend_parameter(self, self.type, value)
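
A minimal sketch of building a parameter with get_simple_instance and casting an incoming value; the parameter name is illustrative:

param = DatasourceParameter.get_simple_instance(
    name="page_size",
    typ=DatasourceParameter.DatasourceParameterType.NUMBER,
    required=False,
)
# cast_value delegates to the shared plugin-parameter casting helper, so a
# string such as "20" would typically be coerced to a number for a NUMBER parameter.
value = param.type.cast_value("20")
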
class DatasourceIdentity(BaseModel):
    author: str = Field(..., description="The author of the datasource")
    name: str = Field(..., description="The name of the datasource")
    label: I18nObject = Field(..., description="The label of the datasource")
    provider: str = Field(..., description="The provider of the datasource")
    icon: str | None = None


class DatasourceEntity(BaseModel):
    identity: DatasourceIdentity
    parameters: list[DatasourceParameter] = Field(default_factory=list)
    description: I18nObject = Field(..., description="The description of the datasource")
    output_schema: dict | None = None

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
        return v or []


class DatasourceProviderIdentity(BaseModel):
    author: str = Field(..., description="The author of the datasource provider")
    name: str = Field(..., description="The name of the datasource provider")
    description: I18nObject = Field(..., description="The description of the datasource provider")
    icon: str = Field(..., description="The icon of the datasource provider")
    label: I18nObject = Field(..., description="The label of the datasource provider")
    tags: list[ToolLabelEnum] | None = Field(
        default=[],
        description="The tags of the datasource provider",
    )

    def generate_datasource_icon_url(self, tenant_id: str) -> str:
        HARD_CODED_DATASOURCE_ICONS = ["https://assets.dify.ai/images/File%20Upload.svg"]
        if self.icon in HARD_CODED_DATASOURCE_ICONS:
            return self.icon
        return str(
            URL(dify_config.CONSOLE_API_URL or "/")
            / "console"
            / "api"
            / "workspaces"
            / "current"
            / "plugin"
            / "icon"
            % {"tenant_id": tenant_id, "filename": self.icon}
        )
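
As a sketch, with CONSOLE_API_URL set to e.g. https://console.example.com, the generated icon URL takes roughly this shape (hard-coded asset URLs are returned untouched); the field values are illustrative:

identity = DatasourceProviderIdentity(
    author="example",
    name="example_provider",
    description=I18nObject(en_US="Example"),
    icon="icon.svg",
    label=I18nObject(en_US="Example"),
)
url = identity.generate_datasource_icon_url(tenant_id="tenant-123")
# e.g. https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id=tenant-123&filename=icon.svg
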
class DatasourceProviderEntity(BaseModel):
    """
    Datasource provider entity
    """

    identity: DatasourceProviderIdentity
    credentials_schema: list[ProviderConfig] = Field(default_factory=list)
    oauth_schema: OAuthSchema | None = None
    provider_type: DatasourceProviderType


class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
    datasources: list[DatasourceEntity] = Field(default_factory=list)


class DatasourceInvokeMeta(BaseModel):
    """
    Datasource invoke meta
    """

    time_cost: float = Field(..., description="The time cost of the datasource invoke")
    error: str | None = None
    tool_config: dict | None = None

    @classmethod
    def empty(cls) -> "DatasourceInvokeMeta":
        """
        Get an empty instance of DatasourceInvokeMeta
        """
        return cls(time_cost=0.0, error=None, tool_config={})

    @classmethod
    def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
        """
        Get an instance of DatasourceInvokeMeta with error
        """
        return cls(time_cost=0.0, error=error, tool_config={})

    def to_dict(self) -> dict:
        return {
            "time_cost": self.time_cost,
            "error": self.error,
            "tool_config": self.tool_config,
        }
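
A short sketch of how the helper constructors are intended to be used when recording an invocation result:

meta = DatasourceInvokeMeta.empty()
meta.time_cost = 0.42  # e.g. measured around the plugin call

failed = DatasourceInvokeMeta.error_instance("connection timed out")
assert failed.to_dict() == {"time_cost": 0.0, "error": "connection timed out", "tool_config": {}}
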
class DatasourceLabel(BaseModel):
    """
    Datasource label
    """

    name: str = Field(..., description="The name of the label")
    label: I18nObject = Field(..., description="The display label")
    icon: str = Field(..., description="The icon of the label")


class DatasourceInvokeFrom(Enum):
    """
    Enum class for datasource invoke
    """

    RAG_PIPELINE = "rag_pipeline"


class OnlineDocumentPage(BaseModel):
    """
    Online document page
    """

    page_id: str = Field(..., description="The page id")
    page_name: str = Field(..., description="The page title")
    page_icon: dict | None = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
    last_edited_time: str = Field(..., description="The last edited time")
    parent_id: str | None = Field(None, description="The parent page id")


class OnlineDocumentInfo(BaseModel):
    """
    Online document info
    """

    workspace_id: str | None = Field(None, description="The workspace id")
    workspace_name: str | None = Field(None, description="The workspace name")
    workspace_icon: str | None = Field(None, description="The workspace icon")
    total: int = Field(..., description="The total number of documents")
    pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")


class OnlineDocumentPagesMessage(BaseModel):
    """
    Get online document pages response
    """

    result: list[OnlineDocumentInfo]


class GetOnlineDocumentPageContentRequest(BaseModel):
    """
    Get online document page content request
    """

    workspace_id: str = Field(..., description="The workspace id")
    page_id: str = Field(..., description="The page id")
    type: str = Field(..., description="The type of the page")


class OnlineDocumentPageContent(BaseModel):
    """
    Online document page content
    """

    workspace_id: str = Field(..., description="The workspace id")
    page_id: str = Field(..., description="The page id")
    content: str = Field(..., description="The content of the page")


class GetOnlineDocumentPageContentResponse(BaseModel):
    """
    Get online document page content response
    """

    result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
    """
    Get website crawl request
    """

    crawl_parameters: dict = Field(..., description="The crawl parameters")


class WebSiteInfoDetail(BaseModel):
    source_url: str = Field(..., description="The url of the website")
    content: str = Field(..., description="The content of the website")
    title: str = Field(..., description="The title of the website")
    description: str = Field(..., description="The description of the website")


class WebSiteInfo(BaseModel):
    """
    Website info
    """

    status: str | None = Field(..., description="crawl job status")
    web_info_list: list[WebSiteInfoDetail] | None = []
    total: int | None = Field(default=0, description="The total number of websites")
    completed: int | None = Field(default=0, description="The number of completed websites")


class WebsiteCrawlMessage(BaseModel):
    """
    Get website crawl response
    """

    result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
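
A sketch of a crawl result message once a crawl job has finished; total and completed track job progress and the values here are illustrative:

message = WebsiteCrawlMessage(
    result=WebSiteInfo(
        status="completed",
        web_info_list=[
            WebSiteInfoDetail(
                source_url="https://example.com/docs",
                content="...",
                title="Docs",
                description="Example docs page",
            )
        ],
        total=1,
        completed=1,
    )
)
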
class DatasourceMessage(ToolInvokeMessage):
    pass


#########################
# Online drive file
#########################


class OnlineDriveFile(BaseModel):
    """
    Online drive file
    """

    id: str = Field(..., description="The file ID")
    name: str = Field(..., description="The file name")
    size: int = Field(..., description="The file size")
    type: str = Field(..., description="The file type: folder or file")


class OnlineDriveFileBucket(BaseModel):
    """
    Online drive file bucket
    """

    bucket: str | None = Field(None, description="The file bucket")
    files: list[OnlineDriveFile] = Field(..., description="The file list")
    is_truncated: bool = Field(False, description="Whether the result is truncated")
    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")


class OnlineDriveBrowseFilesRequest(BaseModel):
    """
    Get online drive file list request
    """

    bucket: str | None = Field(None, description="The file bucket")
    prefix: str = Field(..., description="The parent folder ID")
    max_keys: int = Field(20, description="Page size for pagination")
    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")


class OnlineDriveBrowseFilesResponse(BaseModel):
    """
    Get online drive file list response
    """

    result: list[OnlineDriveFileBucket] = Field(..., description="The list of file buckets")


class OnlineDriveDownloadFileRequest(BaseModel):
    """
    Get online drive file
    """

    id: str = Field(..., description="The id of the file")
    bucket: str | None = Field(None, description="The name of the bucket")
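
The browse request/response pair is built around cursor-style pagination. A sketch, assuming is_truncated and next_page_parameters come back from the plugin (the cursor value is illustrative):

request = OnlineDriveBrowseFilesRequest(prefix="", max_keys=20)
# ...invoke the online-drive datasource plugin with `request`...
# When a bucket in the response has is_truncated=True, feed its
# next_page_parameters into the follow-up request to fetch the next page:
next_request = OnlineDriveBrowseFilesRequest(
    prefix="",
    max_keys=20,
    next_page_parameters={"cursor": "abc123"},  # illustrative cursor shape
)
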
@@ -0,0 +1,37 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta


class DatasourceProviderNotFoundError(ValueError):
    pass


class DatasourceNotFoundError(ValueError):
    pass


class DatasourceParameterValidationError(ValueError):
    pass


class DatasourceProviderCredentialValidationError(ValueError):
    pass


class DatasourceNotSupportedError(ValueError):
    pass


class DatasourceInvokeError(ValueError):
    pass


class DatasourceApiSchemaError(ValueError):
    pass


class DatasourceEngineInvokeError(Exception):
    meta: DatasourceInvokeMeta

    def __init__(self, meta, **kwargs):
        self.meta = meta
        super().__init__(**kwargs)
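
A sketch of how DatasourceEngineInvokeError is meant to carry invocation metadata alongside the exception (DatasourceInvokeMeta is imported at the top of this module):

try:
    raise DatasourceEngineInvokeError(meta=DatasourceInvokeMeta.error_instance("upstream 502"))
except DatasourceEngineInvokeError as e:
    failure_meta = e.meta.to_dict()  # {"time_cost": 0.0, "error": "upstream 502", "tool_config": {}}
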
@@ -0,0 +1,29 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceProviderType,
)


class LocalFileDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    plugin_unique_identifier: str

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime, icon)
        self.tenant_id = tenant_id
        self.plugin_unique_identifier = plugin_unique_identifier

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.LOCAL_FILE

    def get_icon_url(self, tenant_id: str) -> str:
        return self.icon
@@ -0,0 +1,56 @@
from typing import Any

from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin


class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.LOCAL_FILE

    def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider
        """
        pass

    def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return LocalFileDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
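
A minimal sketch of wiring the controller, assuming provider_entity was fetched from the plugin manager for the tenant; the attribute names on provider_entity and the datasource name are assumptions for illustration only:

controller = LocalFileDatasourcePluginProviderController(
    entity=provider_entity.declaration,  # assumed to be a DatasourceProviderEntityWithPlugin
    plugin_id=provider_entity.plugin_id,
    plugin_unique_identifier=provider_entity.plugin_unique_identifier,
    tenant_id="tenant-123",
)
plugin = controller.get_datasource("upload_file")  # raises ValueError if the name is unknown
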
@@ -0,0 +1,71 @@
from collections.abc import Generator, Mapping
from typing import Any

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceMessage,
    DatasourceProviderType,
    GetOnlineDocumentPageContentRequest,
    OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime, icon)
        self.tenant_id = tenant_id
        self.plugin_unique_identifier = plugin_unique_identifier

    def get_online_document_pages(
        self,
        user_id: str,
        datasource_parameters: Mapping[str, Any],
        provider_type: str,
    ) -> Generator[OnlineDocumentPagesMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.get_online_document_pages(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def get_online_document_page_content(
        self,
        user_id: str,
        datasource_parameters: GetOnlineDocumentPageContentRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.get_online_document_page_content(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            datasource_parameters=datasource_parameters,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.ONLINE_DOCUMENT
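
A sketch of consuming the page-listing generator above; the parameters dict is provider-specific and purely illustrative:

plugin: OnlineDocumentDatasourcePlugin = ...  # obtained via the provider controller
for message in plugin.get_online_document_pages(
    user_id="user-1",
    datasource_parameters={"workspace_id": "ws-1"},  # illustrative, depends on the provider
    provider_type=DatasourceProviderType.ONLINE_DOCUMENT.value,
):
    for info in message.result:
        print(info.workspace_name, info.total, [page.page_name for page in info.pages])
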
@@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin


class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.ONLINE_DOCUMENT

    def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return OnlineDocumentDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )
@@ -0,0 +1,71 @@
from collections.abc import Generator

from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
    DatasourceEntity,
    DatasourceMessage,
    DatasourceProviderType,
    OnlineDriveBrowseFilesRequest,
    OnlineDriveBrowseFilesResponse,
    OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager


class OnlineDriveDatasourcePlugin(DatasourcePlugin):
    tenant_id: str
    plugin_unique_identifier: str
    entity: DatasourceEntity
    runtime: DatasourceRuntime

    def __init__(
        self,
        entity: DatasourceEntity,
        runtime: DatasourceRuntime,
        tenant_id: str,
        icon: str,
        plugin_unique_identifier: str,
    ) -> None:
        super().__init__(entity, runtime, icon)
        self.tenant_id = tenant_id
        self.plugin_unique_identifier = plugin_unique_identifier

    def online_drive_browse_files(
        self,
        user_id: str,
        request: OnlineDriveBrowseFilesRequest,
        provider_type: str,
    ) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
        manager = PluginDatasourceManager()

        return manager.online_drive_browse_files(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            request=request,
            provider_type=provider_type,
        )

    def online_drive_download_file(
        self,
        user_id: str,
        request: OnlineDriveDownloadFileRequest,
        provider_type: str,
    ) -> Generator[DatasourceMessage, None, None]:
        manager = PluginDatasourceManager()

        return manager.online_drive_download_file(
            tenant_id=self.tenant_id,
            user_id=user_id,
            datasource_provider=self.entity.identity.provider,
            datasource_name=self.entity.identity.name,
            credentials=self.runtime.credentials,
            request=request,
            provider_type=provider_type,
        )

    def datasource_provider_type(self) -> str:
        return DatasourceProviderType.ONLINE_DRIVE
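
A sketch of paging through an online drive with the plugin above, mirroring the is_truncated / next_page_parameters fields on OnlineDriveFileBucket; it assumes each call yields a single response, which is an assumption rather than a guarantee of the plugin protocol:

plugin: OnlineDriveDatasourcePlugin = ...  # obtained via the provider controller
request = OnlineDriveBrowseFilesRequest(prefix="", max_keys=20)
while True:
    response = next(  # assumes one yielded response per call
        plugin.online_drive_browse_files(
            user_id="user-1",
            request=request,
            provider_type=DatasourceProviderType.ONLINE_DRIVE.value,
        )
    )
    bucket = response.result[0]
    for file in bucket.files:
        print(file.name, file.size, file.type)
    if not bucket.is_truncated:
        break
    request = OnlineDriveBrowseFilesRequest(
        prefix="", max_keys=20, next_page_parameters=bucket.next_page_parameters
    )
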
@@ -0,0 +1,48 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin


class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
    entity: DatasourceProviderEntityWithPlugin
    plugin_id: str
    plugin_unique_identifier: str

    def __init__(
        self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
    ) -> None:
        super().__init__(entity, tenant_id)
        self.plugin_id = plugin_id
        self.plugin_unique_identifier = plugin_unique_identifier

    @property
    def provider_type(self) -> DatasourceProviderType:
        """
        returns the type of the provider
        """
        return DatasourceProviderType.ONLINE_DRIVE

    def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin:  # type: ignore
        """
        return datasource with given name
        """
        datasource_entity = next(
            (
                datasource_entity
                for datasource_entity in self.entity.datasources
                if datasource_entity.identity.name == datasource_name
            ),
            None,
        )

        if not datasource_entity:
            raise ValueError(f"Datasource with name {datasource_name} not found")

        return OnlineDriveDatasourcePlugin(
            entity=datasource_entity,
            runtime=DatasourceRuntime(tenant_id=self.tenant_id),
            tenant_id=self.tenant_id,
            icon=self.entity.identity.icon,
            plugin_unique_identifier=self.plugin_unique_identifier,
        )