Mirror of https://github.com/langgenius/dify.git (synced 2026-05-01 22:47:15 +08:00)

Commit 0316eb6064: Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2
@@ -460,6 +460,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
 MAX_VARIABLE_SIZE=204800

+# GraphEngine Worker Pool Configuration
+# Minimum number of workers per GraphEngine instance (default: 1)
+GRAPH_ENGINE_MIN_WORKERS=1
+# Maximum number of workers per GraphEngine instance (default: 10)
+GRAPH_ENGINE_MAX_WORKERS=10
+# Queue depth threshold that triggers worker scale up (default: 3)
+GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
+# Seconds of idle time before scaling down workers (default: 5.0)
+GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
+
 # Workflow storage configuration
 # Options: rdbms, hybrid
 # rdbms: Use only the relational database (default)
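These four knobs drive the GraphEngine worker pool's autoscaler. A minimal sketch of the control loop they imply, assuming one worker is added or removed per tick; the class and method names below are invented for illustration, the real logic lives in core.workflow.graph_engine.worker_management:

import time


class WorkerPoolSketch:
    def __init__(self, min_workers=1, max_workers=10, scale_up_threshold=3, scale_down_idle_time=5.0):
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.scale_up_threshold = scale_up_threshold
        self.scale_down_idle_time = scale_down_idle_time
        self.workers = min_workers
        self._last_busy = time.monotonic()

    def tick(self, queue_depth: int) -> None:
        """Called periodically with the current ready-node queue depth."""
        now = time.monotonic()
        if queue_depth > 0:
            self._last_busy = now
        # Scale up one worker at a time while the queue is backed up.
        if queue_depth > self.scale_up_threshold and self.workers < self.max_workers:
            self.workers += 1
        # Scale down after a sustained idle period, never below the minimum.
        elif now - self._last_busy >= self.scale_down_idle_time and self.workers > self.min_workers:
            self.workers -= 1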
api/.importlinter (new file, 122 lines)
@@ -0,0 +1,122 @@
+[importlinter]
+root_packages =
+    core
+    configs
+    controllers
+    models
+    tasks
+    services
+
+[importlinter:contract:workflow]
+name = Workflow
+type=layers
+layers =
+    graph_engine
+    graph_events
+    graph
+    nodes
+    node_events
+    entities
+containers =
+    core.workflow
+ignore_imports =
+    core.workflow.nodes.base.node -> core.workflow.graph_events
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph
+    core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
+    core.workflow.nodes.node_factory -> core.workflow.graph
+
+[importlinter:contract:rsc]
+name = RSC
+type = layers
+layers =
+    graph_engine
+    response_coordinator
+    output_registry
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:worker]
+name = Worker
+type = layers
+layers =
+    graph_engine
+    worker
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:graph-engine-architecture]
+name = Graph Engine Architecture
+type = layers
+layers =
+    graph_engine
+    orchestration
+    command_processing
+    event_management
+    error_handling
+    graph_traversal
+    state_management
+    worker_management
+    domain
+containers =
+    core.workflow.graph_engine
+
+[importlinter:contract:domain-isolation]
+name = Domain Model Isolation
+type = forbidden
+source_modules =
+    core.workflow.graph_engine.domain
+forbidden_modules =
+    core.workflow.graph_engine.worker_management
+    core.workflow.graph_engine.command_channels
+    core.workflow.graph_engine.layers
+    core.workflow.graph_engine.protocols
+
+[importlinter:contract:state-management-layers]
+name = State Management Layers
+type = layers
+layers =
+    execution_tracker
+    node_state_manager
+    edge_state_manager
+containers =
+    core.workflow.graph_engine.state_management
+
+[importlinter:contract:worker-management-layers]
+name = Worker Management Layers
+type = layers
+layers =
+    worker_pool
+    worker_factory
+    dynamic_scaler
+    activity_tracker
+containers =
+    core.workflow.graph_engine.worker_management
+
+[importlinter:contract:error-handling-strategies]
+name = Error Handling Strategies
+type = independence
+modules =
+    core.workflow.graph_engine.error_handling.abort_strategy
+    core.workflow.graph_engine.error_handling.retry_strategy
+    core.workflow.graph_engine.error_handling.fail_branch_strategy
+    core.workflow.graph_engine.error_handling.default_value_strategy
+
+[importlinter:contract:graph-traversal-components]
+name = Graph Traversal Components
+type = independence
+modules =
+    core.workflow.graph_engine.graph_traversal.node_readiness
+    core.workflow.graph_engine.graph_traversal.skip_propagator
+
+[importlinter:contract:command-channels]
+name = Command Channels Independence
+type = independence
+modules =
+    core.workflow.graph_engine.command_channels.in_memory_channel
+    core.workflow.graph_engine.command_channels.redis_channel
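These contracts are enforced by the import-linter package, typically run as its lint-imports CLI from the api/ directory where this config lives. A "layers" contract lets a module import from layers listed below it, never above. A toy version of that check, for illustration only:

# Toy re-implementation of import-linter's "layers" rule.
# Earlier entries in LAYERS are higher layers; importing upward is a violation.
LAYERS = ["graph_engine", "graph_events", "graph", "nodes", "node_events", "entities"]


def violates_layers(importer: str, imported: str) -> bool:
    """True when a lower layer imports a higher one."""
    if importer not in LAYERS or imported not in LAYERS:
        return False  # outside the container: not constrained by this contract
    return LAYERS.index(importer) > LAYERS.index(imported)


assert violates_layers("entities", "nodes")      # upward import: flagged
assert not violates_layers("nodes", "entities")  # downward import: allowed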
@@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from configs import dify_config
 from constants.languages import languages
 from core.helper import encrypter
-from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
+from core.plugin.entities.plugin import PluginInstallationSource
 from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType

@@ -35,6 +35,7 @@ from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
 from models.provider import Provider, ProviderModel
+from models.provider_ids import DatasourceProviderID, ToolProviderID
 from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
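This pair of hunks shows a pattern that repeats throughout the commit: the provider-ID classes (DatasourceProviderID, ToolProviderID, and later ModelProviderID) move from core.plugin.entities.plugin to a new models.provider_ids module, so the models layer no longer reaches into core for these identifiers. A hypothetical sketch of such an ID value object, assuming the usual organization/plugin/provider string form; the real classes' fields and parsing may differ:

from dataclasses import dataclass


@dataclass(frozen=True)
class ProviderIDSketch:
    organization: str
    plugin_name: str
    provider_name: str

    @classmethod
    def from_string(cls, value: str) -> "ProviderIDSketch":
        # Assumed format: "organization/plugin_name/provider_name"
        organization, plugin_name, provider_name = value.split("/")
        return cls(organization, plugin_name, provider_name)


pid = ProviderIDSketch.from_string("langgenius/google/google")
assert pid.provider_name == "google"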
@@ -1570,7 +1571,7 @@ def transform_datasource_credentials():
         click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
     )
     click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))


 @click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
 @click.option(

@@ -1591,4 +1592,4 @@ def install_rag_pipeline_plugins(input_file, output_file, workers):
         output_file,
         workers,
     )
     click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
@@ -529,6 +529,28 @@ class WorkflowConfig(BaseSettings):
         default=200 * 1024,
     )

+    # GraphEngine Worker Pool Configuration
+    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
+        description="Minimum number of workers per GraphEngine instance",
+        default=1,
+    )
+
+    GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
+        description="Maximum number of workers per GraphEngine instance",
+        default=10,
+    )
+
+    GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
+        description="Queue depth threshold that triggers worker scale up",
+        default=3,
+    )
+
+    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
+        description="Seconds of idle time before scaling down workers",
+        default=5.0,
+        ge=0.1,
+    )
+

 class WorkflowNodeExecutionConfig(BaseSettings):
     """
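These fields mirror the .env entries added above: pydantic's BaseSettings fills each field from the environment variable of the same name and validates it at startup, so a bad value fails fast instead of surfacing inside the worker pool. A self-contained sketch of that mechanism, assuming pydantic v2 with pydantic-settings (which dify's configs use); this is not dify's actual class:

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class WorkerPoolSettingsSketch(BaseSettings):
    GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(default=1)
    GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(default=10)
    GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(default=5.0, ge=0.1)


os.environ["GRAPH_ENGINE_MAX_WORKERS"] = "32"   # e.g. set via .env
settings = WorkerPoolSettingsSketch()
assert settings.GRAPH_ENGINE_MAX_WORKERS == 32  # env overrides the default
# A value like "0" for GRAPH_ENGINE_MIN_WORKERS would fail PositiveInt
# validation here rather than misconfiguring the pool at runtime.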
@@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
 from libs.login import login_required
+from models import App
+from services.workflow_service import WorkflowService


 class RuleGenerateApi(Resource):

@@ -135,9 +138,6 @@ class InstructionGenerateApi(Resource):
         try:
             # Generate from nothing for a workflow node
             if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
-                from models import App, db
-                from services.workflow_service import WorkflowService
-
                 app = db.session.query(App).where(App.id == args["flow_id"]).first()
                 if not app:
                     return {"error": f"app {args['flow_id']} not found"}, 400

@@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file.models import File
 from core.helper.trace_id_helper import get_external_trace_id
+from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from factories import file_factory, variable_factory
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -413,7 +414,12 @@ class WorkflowTaskStopApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}
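The same two calls replace AppQueueManager.set_stop_flag in the explore, service-api, and web stop endpoints later in this diff. Taken together they amount to this dispatch; the helper function is illustrative, not part of the diff:

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.workflow.graph_engine.manager import GraphEngineManager


def stop_workflow_task(task_id: str) -> None:
    # Legacy queue-based stop flag, kept so runs driven by the old queue
    # manager still halt (the per-user check is dropped).
    AppQueueManager.set_stop_flag_no_user_check(task_id)
    # New mechanism: push a stop command onto the graph engine's command
    # channel; the running GraphEngine picks it up and aborts the graph.
    GraphEngineManager.send_stop_command(task_id)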
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_database import db
 from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs.login import login_required

@@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
 from libs.login import current_user, login_required
-from models import App, AppMode, db
+from models import App, AppMode
 from models.account import Account
 from models.workflow import WorkflowDraftVariable
 from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService

@@ -19,7 +19,6 @@ from controllers.console.wraps import (
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting

@@ -31,6 +30,7 @@ from fields.document_fields import document_status_fields
 from libs.login import login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService

@@ -11,10 +11,10 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.plugin.entities.plugin import DatasourceProviderID
 from core.plugin.impl.oauth import OAuthHandler
 from libs.helper import StrLen
 from libs.login import login_required
+from models.provider_ids import DatasourceProviderID
 from services.datasource_provider_service import DatasourceProviderService
 from services.plugin.oauth_service import OAuthProxyService

@@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
 from libs.login import current_user, login_required
-from models import db
+from models.account import Account
 from models.dataset import Pipeline
 from models.workflow import WorkflowDraftVariable
 from services.rag_pipeline.rag_pipeline import RagPipelineService

@@ -131,7 +132,7 @@ def _api_prerequisite(f):
     @account_initialization_required
     @get_rag_pipeline
     def wrapper(*args, **kwargs):
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         return f(*args, **kwargs)
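The long run of hunks that follows applies one guard everywhere in the RAG pipeline controllers: current_user may be an Account (console user) or an EndUser, and only Account defines is_editor, so each permission check first narrows the type before reading the attribute. The pattern in isolation; require_editor is an illustrative name, not a helper introduced by the diff:

from werkzeug.exceptions import Forbidden

from models.account import Account


def require_editor(current_user) -> None:
    # Rejects EndUser/None outright and narrows the type so that the
    # .is_editor attribute access type-checks against Account.
    if not isinstance(current_user, Account) or not current_user.is_editor:
        raise Forbidden()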
@@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
         Get draft rag pipeline's workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         # fetch draft workflow by app_model

@@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
         Sync draft workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         content_type = request.headers.get("Content-Type", "")

@@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
         Run draft workflow iteration node
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):

@@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         Run draft workflow loop node
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):

@@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource):
         Run draft workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):

@@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource):
         Run published workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):
@@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource):
         #
         # return result
         #
-
-

 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required

@@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         Run rag pipeline datasource
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):

@@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
         Run rag pipeline datasource
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()

@@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource):
         Run draft workflow node
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource):
         Stop workflow task
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)

@@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource):
         Get published pipeline
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         if not pipeline.is_published:
             return None

@@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource):
         Publish workflow
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         rag_pipeline_service = RagPipelineService()

@@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
         Get default block config
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         # Get default block configs

@@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         Get default block config
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         if not isinstance(current_user, Account):

@@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource):
         """
         Get published workflows
         """
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource):
         Update workflow attributes
         """
         # Check permission
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()

@@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
         Get second step parameters of rag pipeline
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("node_id", type=str, required=True, location="args")

@@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
         Get first step parameters of rag pipeline
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("node_id", type=str, required=True, location="args")

@@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource):
         Get first step parameters of rag pipeline
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("node_id", type=str, required=True, location="args")

@@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource):
         Get second step parameters of rag pipeline
         """
         # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("node_id", type=str, required=True, location="args")

@@ -926,8 +912,11 @@ class DatasourceListApi(Resource):
     @account_initialization_required
     def get(self):
         user = current_user
+        if not isinstance(user, Account):
+            raise Forbidden()
         tenant_id = user.current_tenant_id
+        if not tenant_id:
+            raise Forbidden()

         return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))

@@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource):
         """
         Set datasource variables
        """
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        if not isinstance(current_user, Account):
+        if not isinstance(current_user, Account) or not current_user.is_editor:
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -5,6 +5,7 @@ from typing import Optional
 from controllers.console.datasets.error import PipelineNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
+from models.account import Account
 from models.dataset import Pipeline

@@ -17,6 +18,9 @@ def get_rag_pipeline(
     if not kwargs.get("pipeline_id"):
         raise ValueError("missing pipeline_id in path parameters")

+    if not isinstance(current_user, Account):
+        raise ValueError("current_user is not an account")
+
     pipeline_id = kwargs.get("pipeline_id")
     pipeline_id = str(pipeline_id)
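For context, a paraphrase of the decorator being patched; only the isinstance guard is new in this commit, the surrounding structure is reconstructed from the hunk and is not verbatim:

from functools import wraps

from libs.login import current_user
from models.account import Account


def get_rag_pipeline(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not kwargs.get("pipeline_id"):
            raise ValueError("missing pipeline_id in path parameters")

        # New in this commit: fail fast when the session user is not a
        # console Account, before any tenant-scoped pipeline lookup runs.
        if not isinstance(current_user, Account):
            raise ValueError("current_user is not an account")

        pipeline_id = str(kwargs.get("pipeline_id"))
        ...  # look up the Pipeline by pipeline_id and hand it to the view
        return view(*args, **kwargs)

    return decorated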
@@ -20,6 +20,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
 from libs.login import current_user
 from models.model import AppMode, InstalledApp

@@ -78,6 +79,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
             raise NotWorkflowAppError()
         assert current_user is not None

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}

@@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource):
         return [], 200


 api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
@@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
 from core.mcp.error import MCPAuthError, MCPError
 from core.mcp.mcp_client import MCPClient
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ToolProviderID
 from core.plugin.impl.oauth import OAuthHandler
 from core.tools.entities.tool_entities import CredentialType
 from libs.helper import StrLen, alphanumeric, uuid_value
 from libs.login import login_required
+from models.provider_ids import ToolProviderID
 from services.plugin.oauth_service import OAuthProxyService
 from services.tools.api_tools_manage_service import ApiToolManageService
 from services.tools.builtin_tools_manage_service import BuiltinToolManageService

@@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
 from controllers.files import files_ns
 from core.tools.signature import verify_tool_file_signature
 from core.tools.tool_file_manager import ToolFileManager
-from models import db as global_db
+from extensions.ext_database import db as global_db


 @files_ns.route("/tools/<uuid:file_id>.<string:extension>")
@@ -26,7 +26,8 @@ from core.errors.error import (
 )
 from core.helper.trace_id_helper import get_external_trace_id
 from core.model_runtime.errors.invoke import InvokeError
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
+from core.workflow.enums import WorkflowExecutionStatus
+from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
 from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
 from libs import helper

@@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}
@@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
     validate_dataset_token,
 )
 from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import build_dataset_tag_fields
 from libs.login import current_user
 from models.account import Account
 from models.dataset import Dataset, DatasetPermissionEnum
+from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
 from services.tag_service import TagService

@@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource):
         # validate args
         DocumentService.document_create_args_validate(knowledge_config)

+        if not current_user:
+            raise ValueError("current_user is required")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
@@ -21,6 +21,7 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from core.workflow.graph_engine.manager import GraphEngineManager
 from libs import helper
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService

@@ -75,7 +76,12 @@ class WorkflowTaskStopApi(WebApiResource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
+        # Stop using both mechanisms for backward compatibility
+        # Legacy stop flag mechanism (without user check)
+        AppQueueManager.set_stop_flag_no_user_check(task_id)
+
+        # New graph engine command channel mechanism
+        GraphEngineManager.send_stop_command(task_id)

         return {"result": "success"}
@@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
             tenant_id=tenant_id,
             dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
-            return_resource=app_config.additional_features.show_retrieve_source,
+            return_resource=(
+                app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+            ),
             invoke_from=application_generate_entity.invoke_from,
             hit_callback=hit_callback,
             user_id=user_id,

@@ -4,8 +4,8 @@ from typing import Any
 from core.app.app_config.entities import ModelConfigEntity
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
+from models.provider_ids import ModelProviderID


 class ModelConfigManager:

@@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
-            app_config.additional_features.show_retrieve_source = True
+            app_config.additional_features.show_retrieve_source = True  # type: ignore

         workflow_run_id = str(uuid.uuid4())
         # init application generate entity
@@ -1,11 +1,11 @@
 import logging
+import time
 from collections.abc import Mapping
 from typing import Any, Optional, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session

-from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner

@@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from models import Workflow
 from models.enums import UserFrom
 from models.model import App, Conversation, Message, MessageAnnotation
-from models.workflow import ConversationVariable, WorkflowType
+from models.workflow import ConversationVariable

 logger = logging.getLogger(__name__)
@@ -76,23 +77,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         if not app_record:
             raise ValueError("App not found")

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         if self.application_generate_entity.single_iteration_run:
             # if only single iteration run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
                 user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
+                graph_runtime_state=graph_runtime_state,
             )
         elif self.application_generate_entity.single_loop_run:
             # if only single loop run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_loop_run.node_id,
                 user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
+                graph_runtime_state=graph_runtime_state,
             )
         else:
             inputs = self.application_generate_entity.inputs
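Single-iteration and single-loop debug runs now build their GraphRuntimeState up front from an empty variable pool, so the graph carries its runtime state from construction. In isolation, assuming only the two constructor arguments the diff actually passes:

import time

from core.workflow.entities import GraphRuntimeState, VariablePool

graph_runtime_state = GraphRuntimeState(
    variable_pool=VariablePool.empty(),  # debug runs start with no variables
    start_at=time.time(),                # wall-clock start of the run
)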
@@ -144,16 +151,27 @@
         )

         # init graph
-        graph = self._init_graph(graph_config=self._workflow.graph_dict)
+        graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
+        graph = self._init_graph(
+            graph_config=self._workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
+            workflow_id=self._workflow.id,
+            tenant_id=self._workflow.tenant_id,
+            user_id=self.application_generate_entity.user_id,
+        )

         db.session.close()

         # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
         workflow_entry = WorkflowEntry(
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
             workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,

@@ -164,12 +182,11 @@
             ),
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
-            variable_pool=variable_pool,
+            graph_runtime_state=graph_runtime_state,
+            command_channel=command_channel,
         )

-        generator = workflow_entry.run(
-            callbacks=workflow_callbacks,
-        )
+        generator = workflow_entry.run()

         for event in generator:
             self._handle_event(workflow_entry, event)
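This is the receiving half of the stop-command mechanism added to the stop endpoints earlier in the diff. The runner builds a RedisChannel keyed by task id; presumably GraphEngineManager.send_stop_command derives the same key, so a stop issued from any API pod reaches the engine that owns the run. A sketch under that assumption; the key format is taken verbatim from the diff, the polling details inside GraphEngine are not shown in this commit:

from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from extensions.ext_redis import redis_client

task_id = "some-task-id"  # placeholder
channel_key = f"workflow:{task_id}:commands"

# Runner side: the engine consumes commands from this channel during the run.
command_channel = RedisChannel(redis_client, channel_key)

# API side: GraphEngineManager.send_stop_command(task_id) enqueues a stop
# command that, assuming the shared key construction above, lands on this
# channel and aborts the running graph.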
@@ -30,14 +30,9 @@ from core.app.entities.queue_entities import (
     QueueMessageReplaceEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,

@@ -64,8 +59,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphRuntimeState
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository

@@ -393,9 +388,7 @@

     def _handle_node_failed_events(
         self,
-        event: Union[
-            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
-        ],
+        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
@ -440,32 +433,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_parallel_branch_started_event(
|
|
||||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle parallel branch started events."""
|
|
||||||
self._ensure_workflow_initialized()
|
|
||||||
|
|
||||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_execution_id=self._workflow_run_id,
|
|
||||||
event=event,
|
|
||||||
)
|
|
||||||
yield parallel_start_resp
|
|
||||||
|
|
||||||
def _handle_parallel_branch_finished_events(
|
|
||||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle parallel branch finished events."""
|
|
||||||
self._ensure_workflow_initialized()
|
|
||||||
|
|
||||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_execution_id=self._workflow_run_id,
|
|
||||||
event=event,
|
|
||||||
)
|
|
||||||
yield parallel_finish_resp
|
|
||||||
|
|
||||||
def _handle_iteration_start_event(
|
def _handle_iteration_start_event(
|
||||||
self, event: QueueIterationStartEvent, **kwargs
|
self, event: QueueIterationStartEvent, **kwargs
|
||||||
) -> Generator[StreamResponse, None, None]:
|
) -> Generator[StreamResponse, None, None]:
|
||||||
@ -757,8 +724,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||||
# Parallel branch events
|
|
||||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
|
||||||
# Iteration events
|
# Iteration events
|
||||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||||
@ -806,8 +771,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
event,
|
event,
|
||||||
(
|
(
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
|
||||||
QueueNodeInLoopFailedEvent,
|
|
||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
@ -820,17 +783,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle parallel branch finished events with isinstance check
|
|
||||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
|
||||||
yield from self._handle_parallel_branch_finished_events(
|
|
||||||
event,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
tts_publisher=tts_publisher,
|
|
||||||
trace_manager=trace_manager,
|
|
||||||
queue_message=queue_message,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# For unhandled events, we continue (original behavior)
|
# For unhandled events, we continue (original behavior)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -854,11 +806,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
graph_runtime_state = event.graph_runtime_state
|
graph_runtime_state = event.graph_runtime_state
|
||||||
yield from self._handle_workflow_started_event(event)
|
yield from self._handle_workflow_started_event(event)
|
||||||
|
|
||||||
case QueueTextChunkEvent():
|
|
||||||
yield from self._handle_text_chunk_event(
|
|
||||||
event, tts_publisher=tts_publisher, queue_message=queue_message
|
|
||||||
)
|
|
||||||
|
|
||||||
case QueueErrorEvent():
|
case QueueErrorEvent():
|
||||||
yield from self._handle_error_event(event)
|
yield from self._handle_error_event(event)
|
||||||
break
|
break
|
||||||
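
The pipeline above routes queue events through an exact-type handler table, then falls back to isinstance checks for event families that share a single handler. A self-contained sketch of that dispatch pattern (all names here are hypothetical stand-ins, not dify classes):

    from collections.abc import Callable, Generator


    class PingEvent: ...


    class FailedEvent: ...


    class ExceptionEvent(FailedEvent): ...


    class Dispatcher:
        def __init__(self) -> None:
            # Exact-type table for the common, one-to-one cases.
            self._handlers: dict[type, Callable[[object], Generator[str, None, None]]] = {
                PingEvent: self._handle_ping,
            }

        def dispatch(self, event: object) -> Generator[str, None, None]:
            handler = self._handlers.get(type(event))
            if handler is not None:
                yield from handler(event)
                return
            # isinstance fallback for families handled by one method.
            if isinstance(event, FailedEvent):
                yield from self._handle_failed(event)
                return
            # Unhandled events fall through (original behavior).

        def _handle_ping(self, event: object) -> Generator[str, None, None]:
            yield "pong"

        def _handle_failed(self, event: object) -> Generator[str, None, None]:
            yield f"failed: {type(event).__name__}"


    for chunk in Dispatcher().dispatch(ExceptionEvent()):
        print(chunk)  # -> failed: ExceptionEvent
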
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
 from core.app.app_config.entities import VariableEntityType
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.file import File, FileUploadConfig
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType
 from core.workflow.repositories.draft_variable_repository import (
     DraftVariableSaver,
     DraftVariableSaverFactory,

@@ -126,6 +126,21 @@ class AppQueueManager:
         stopped_cache_key = cls._generate_stopped_cache_key(task_id)
         redis_client.setex(stopped_cache_key, 600, 1)

+    @classmethod
+    def set_stop_flag_no_user_check(cls, task_id: str) -> None:
+        """
+        Set task stop flag without user permission check.
+        This method allows stopping workflows without user context.
+
+        :param task_id: The task ID to stop
+        :return:
+        """
+        if not task_id:
+            return
+
+        stopped_cache_key = cls._generate_stopped_cache_key(task_id)
+        redis_client.setex(stopped_cache_key, 600, 1)
+
     def _is_stopped(self) -> bool:
         """
         Check if task is stopped
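
The new classmethod writes the same Redis stop marker as the user-checked variant but skips the permission gate, so system-level callers holding only a task id can abort a run. A hedged usage sketch (the task id value is hypothetical; the import path matches the AppQueueManager import used elsewhere in this commit):

    from core.app.apps.base_app_queue_manager import AppQueueManager

    # Sets a Redis key with a 600-second TTL; _is_stopped() sees it on the
    # next queue poll and the workflow winds down without a user context.
    AppQueueManager.set_stop_flag_no_user_check("some-task-id")
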
@@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner):
                 config=app_config.dataset,
                 query=query,
                 invoke_from=application_generate_entity.invoke_from,
-                show_retrieve_source=app_config.additional_features.show_retrieve_source,
+                show_retrieve_source=(
+                    app_config.additional_features.show_retrieve_source if app_config.additional_features else False
+                ),
                 hit_callback=hit_callback,
                 memory=memory,
                 message_id=message.id,

@@ -17,14 +17,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
     AgentLogStreamResponse,
@@ -37,20 +32,18 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeRetryStreamResponse,
     NodeStartStreamResponse,
-    ParallelBranchFinishedStreamResponse,
-    ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
 )
 from core.file import FILE_MODEL_IDENTITY, File
 from core.plugin.impl.datasource import PluginDatasourceManager
+from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
-from core.workflow.entities.workflow_execution import WorkflowExecution
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
+from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
+from core.workflow.enums import WorkflowNodeExecutionStatus
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.datasource.entities import DatasourceNodeData
-from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import (
@@ -180,11 +173,10 @@ class WorkflowResponseConverter:

         # extras logic
         if event.node_type == NodeType.TOOL:
-            node_data = cast(ToolNodeData, event.node_data)
             response.data.extras["icon"] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
-                provider_type=node_data.provider_type,
-                provider_id=node_data.provider_id,
+                provider_type=ToolProviderType(event.provider_type),
+                provider_id=event.provider_id,
             )
         elif event.node_type == NodeType.DATASOURCE:
             node_data = cast(DatasourceNodeData, event.node_data)
@@ -200,11 +192,7 @@ class WorkflowResponseConverter:
     def workflow_node_finish_to_stream_response(
         self,
         *,
-        event: QueueNodeSucceededEvent
-        | QueueNodeFailedEvent
-        | QueueNodeInIterationFailedEvent
-        | QueueNodeInLoopFailedEvent
-        | QueueNodeExceptionEvent,
+        event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
@@ -238,9 +226,6 @@ class WorkflowResponseConverter:
                 finished_at=int(workflow_node_execution.finished_at.timestamp()),
                 files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
                 parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
             ),
@@ -292,50 +277,6 @@ class WorkflowResponseConverter:
             ),
         )

-    def workflow_parallel_branch_start_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunStartedEvent,
-    ) -> ParallelBranchStartStreamResponse:
-        return ParallelBranchStartStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchStartStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                created_at=int(time.time()),
-            ),
-        )
-
-    def workflow_parallel_branch_finished_to_stream_response(
-        self,
-        *,
-        task_id: str,
-        workflow_execution_id: str,
-        event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
-    ) -> ParallelBranchFinishedStreamResponse:
-        return ParallelBranchFinishedStreamResponse(
-            task_id=task_id,
-            workflow_run_id=workflow_execution_id,
-            data=ParallelBranchFinishedStreamResponse.Data(
-                parallel_id=event.parallel_id,
-                parallel_branch_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                iteration_id=event.in_iteration_id,
-                loop_id=event.in_loop_id,
-                status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
-                error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
-                created_at=int(time.time()),
-            ),
-        )
-
     def workflow_iteration_start_to_stream_response(
         self,
         *,
@@ -350,13 +291,11 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 created_at=int(time.time()),
                 extras={},
                 inputs=event.inputs or {},
                 metadata=event.metadata or {},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )

@@ -374,15 +313,10 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 index=event.index,
-                pre_iteration_output=event.output,
                 created_at=int(time.time()),
                 extras={},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parallel_mode_run_id=event.parallel_mode_run_id,
-                duration=event.duration,
             ),
         )

@@ -401,7 +335,7 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 outputs=json_converter.to_json_encodable(event.outputs),
                 created_at=int(time.time()),
                 extras={},
@@ -415,8 +349,6 @@ class WorkflowResponseConverter:
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )

@@ -430,7 +362,7 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 created_at=int(time.time()),
                 extras={},
                 inputs=event.inputs or {},
@@ -454,7 +386,7 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 index=event.index,
                 pre_loop_output=event.output,
                 created_at=int(time.time()),
@@ -462,7 +394,6 @@ class WorkflowResponseConverter:
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
                 parallel_mode_run_id=event.parallel_mode_run_id,
-                duration=event.duration,
             ),
         )

@@ -480,7 +411,7 @@ class WorkflowResponseConverter:
                 id=event.node_id,
                 node_id=event.node_id,
                 node_type=event.node_type.value,
-                title=event.node_data.title,
+                title=event.node_title,
                 outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
                 created_at=int(time.time()),
                 extras={},
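
Throughout the converter, titles now come from event.node_title, and for tool nodes the provider info is carried on the queue event itself rather than recovered by casting event.node_data. A reduced sketch of the new lookup (the event here is a stand-in dataclass, not the real QueueNodeStartedEvent):

    from dataclasses import dataclass


    @dataclass
    class ToolNodeEventStub:
        node_title: str
        provider_type: str  # carried on the event, e.g. "builtin"
        provider_id: str


    event = ToolNodeEventStub(node_title="Search", provider_type="builtin", provider_id="google")

    # Real code: ToolManager.get_tool_icon(tenant_id=...,
    #     provider_type=ToolProviderType(event.provider_type),
    #     provider_id=event.provider_id)
    print(event.node_title, event.provider_type, event.provider_id)
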
@@ -1,8 +1,7 @@
 import logging
-from collections.abc import Mapping
-from typing import Any, Optional, cast
+import time
+from typing import Optional, cast

-from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import (
     RagPipelineGenerateEntity,
 )
 from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
+from core.workflow.entities.graph_init_params import GraphInitParams
+from core.workflow.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
-from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph import Graph
+from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
+from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
@@ -22,7 +23,7 @@ from extensions.ext_database import db
 from models.dataset import Document, Pipeline
 from models.enums import UserFrom
 from models.model import EndUser
-from models.workflow import Workflow, WorkflowType
+from models.workflow import Workflow

 logger = logging.getLogger(__name__)

@@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner):

         db.session.close()

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         # if only single iteration run is requested
         if self.application_generate_entity.single_iteration_run:
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             # if only single iteration run is requested
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
                 user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                graph_runtime_state=graph_runtime_state,
             )
         elif self.application_generate_entity.single_loop_run:
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             # if only single loop run is requested
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_loop_run.node_id,
                 user_inputs=self.application_generate_entity.single_loop_run.inputs,
+                graph_runtime_state=graph_runtime_state,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
                 datasource_info=self.application_generate_entity.datasource_info,
                 invoke_from=self.application_generate_entity.invoke_from.value,
             )
+
             rag_pipeline_variables = []
             if workflow.rag_pipeline_variables:
                 for v in workflow.rag_pipeline_variables:
@@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
                 conversation_variables=[],
                 rag_pipeline_variables=rag_pipeline_variables,
             )
+            graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

         # init graph
         graph = self._init_rag_pipeline_graph(
-            graph_config=workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
             start_node_id=self.application_generate_entity.start_node_id,
+            workflow=workflow,
         )

         # RUN WORKFLOW
@@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
             tenant_id=workflow.tenant_id,
             app_id=workflow.app_id,
             workflow_id=workflow.id,
-            workflow_type=WorkflowType.value_of(workflow.type),
             graph=graph,
             graph_config=workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
@@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
             ),
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
-            variable_pool=variable_pool,
-            thread_pool_id=self.workflow_thread_pool_id,
+            graph_runtime_state=graph_runtime_state,
         )

-        generator = workflow_entry.run(callbacks=workflow_callbacks)
+        generator = workflow_entry.run()

         for event in generator:
             self._update_document_status(
@@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
         # return workflow
         return workflow

-    def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
+    def _init_rag_pipeline_graph(
+        self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
+    ) -> Graph:
         """
         Init pipeline graph
         """
+        graph_config = workflow.graph_dict
         if "nodes" not in graph_config or "edges" not in graph_config:
             raise ValueError("nodes or edges not found in workflow graph")

@@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
         graph_config["nodes"] = real_run_nodes
         graph_config["edges"] = real_edges
         # init graph
-        graph = Graph.init(graph_config=graph_config)
+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            workflow_id=workflow.id,
+            graph_config=graph_config,
+            user_id="",
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)

         if not graph:
             raise ValueError("graph not found in workflow")
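
Graph construction is now a three-step affair: build GraphInitParams, wrap them in a DifyNodeFactory bound to the shared GraphRuntimeState, then hand the factory to Graph.init. A sketch of the sequence (the free variables workflow, graph_config, graph_runtime_state, and start_node_id are the ones already in scope in the hunk above; app_id stands in for self._app_id):

    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.workflow.entities.graph_init_params import GraphInitParams
    from core.workflow.graph import Graph
    from core.workflow.nodes.node_factory import DifyNodeFactory
    from models.enums import UserFrom

    graph_init_params = GraphInitParams(
        tenant_id=workflow.tenant_id,
        app_id=app_id,
        workflow_id=workflow.id,
        graph_config=graph_config,
        user_id="",
        user_from=UserFrom.ACCOUNT.value,
        invoke_from=InvokeFrom.SERVICE_API.value,
        call_depth=0,
    )
    node_factory = DifyNodeFactory(
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,  # one state shared by every node
    )
    graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
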
@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Literal, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[True],
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
     ) -> Generator[Mapping | str, None, None]: ...

     @overload
@@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[False],
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
     ) -> Mapping[str, Any]: ...

     @overload
@@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool,
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...

     def generate(
@@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []

@@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
-            workflow_thread_pool_id=workflow_thread_pool_id,
         )

     def _generate(
@@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         streaming: bool = True,
-        workflow_thread_pool_id: Optional[str] = None,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
@@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
         :param workflow_execution_repository: repository for workflow execution
         :param workflow_node_execution_repository: repository for workflow node execution
         :param streaming: is stream
-        :param workflow_thread_pool_id: workflow thread pool id
         """
         # init queue manager
         queue_manager = WorkflowAppQueueManager(
@@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "context": context,
-                "workflow_thread_pool_id": workflow_thread_pool_id,
                 "variable_loader": variable_loader,
             },
         )
@@ -434,17 +426,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         queue_manager: AppQueueManager,
         context: contextvars.Context,
         variable_loader: VariableLoader,
-        workflow_thread_pool_id: Optional[str] = None,
     ) -> None:
-        """
-        Generate worker in a new thread.
-        :param flask_app: Flask app
-        :param application_generate_entity: application generate entity
-        :param queue_manager: queue manager
-        :param workflow_thread_pool_id: workflow thread pool id
-        :return:
-        """
-
         with preserve_flask_contexts(flask_app, context_vars=context):
             with Session(db.engine, expire_on_commit=False) as session:
                 workflow = session.scalar(
@@ -474,7 +456,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
             runner = WorkflowAppRunner(
                 application_generate_entity=application_generate_entity,
                 queue_manager=queue_manager,
-                workflow_thread_pool_id=workflow_thread_pool_id,
                 variable_loader=variable_loader,
                 workflow=workflow,
                 system_user_id=system_user_id,

@@ -1,7 +1,7 @@
 import logging
-from typing import Optional, cast
+import time
+from typing import cast

-from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
@@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
-from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities import GraphRuntimeState, VariablePool
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_redis import redis_client
 from models.enums import UserFrom
-from models.workflow import Workflow, WorkflowType
+from models.workflow import Workflow

 logger = logging.getLogger(__name__)

@@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         application_generate_entity: WorkflowAppGenerateEntity,
         queue_manager: AppQueueManager,
         variable_loader: VariableLoader,
-        workflow_thread_pool_id: Optional[str] = None,
         workflow: Workflow,
         system_user_id: str,
     ) -> None:
@@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             app_id=application_generate_entity.app_config.app_id,
         )
         self.application_generate_entity = application_generate_entity
-        self.workflow_thread_pool_id = workflow_thread_pool_id
         self._workflow = workflow
         self._sys_user_id = system_user_id

@@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)

-        workflow_callbacks: list[WorkflowCallback] = []
-        if dify_config.DEBUG:
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         # if only single iteration run is requested
         if self.application_generate_entity.single_iteration_run:
             # if only single iteration run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
                 user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                graph_runtime_state=graph_runtime_state,
             )
         elif self.application_generate_entity.single_loop_run:
             # if only single loop run is requested
+            graph_runtime_state = GraphRuntimeState(
+                variable_pool=VariablePool.empty(),
+                start_at=time.time(),
+            )
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
                 workflow=self._workflow,
                 node_id=self.application_generate_entity.single_loop_run.node_id,
                 user_inputs=self.application_generate_entity.single_loop_run.inputs,
+                graph_runtime_state=graph_runtime_state,
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -92,15 +97,26 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
                 conversation_variables=[],
             )

+            graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
         # init graph
-        graph = self._init_graph(graph_config=self._workflow.graph_dict)
+        graph = self._init_graph(
+            graph_config=self._workflow.graph_dict,
+            graph_runtime_state=graph_runtime_state,
+            workflow_id=self._workflow.id,
+            tenant_id=self._workflow.tenant_id,
+        )

         # RUN WORKFLOW
+        # Create Redis command channel for this workflow execution
+        task_id = self.application_generate_entity.task_id
+        channel_key = f"workflow:{task_id}:commands"
+        command_channel = RedisChannel(redis_client, channel_key)
+
         workflow_entry = WorkflowEntry(
             tenant_id=self._workflow.tenant_id,
             app_id=self._workflow.app_id,
             workflow_id=self._workflow.id,
-            workflow_type=WorkflowType.value_of(self._workflow.type),
             graph=graph,
             graph_config=self._workflow.graph_dict,
             user_id=self.application_generate_entity.user_id,
@@ -111,11 +127,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             ),
             invoke_from=self.application_generate_entity.invoke_from,
             call_depth=self.application_generate_entity.call_depth,
-            variable_pool=variable_pool,
-            thread_pool_id=self.workflow_thread_pool_id,
+            graph_runtime_state=graph_runtime_state,
+            command_channel=command_channel,
         )

-        generator = workflow_entry.run(callbacks=workflow_callbacks)
+        generator = workflow_entry.run()

         for event in generator:
             self._handle_event(workflow_entry, event)
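
The runner above opens a per-execution Redis command channel keyed by task id, so anything that knows the task id can push commands (for example a stop request) to the running GraphEngine. A sketch of the wiring (the task id value is hypothetical; the key format is taken from the hunk):

    from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
    from extensions.ext_redis import redis_client

    task_id = "example-task-id"  # taken from the generate entity in the real runner
    command_channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
    # WorkflowEntry(..., command_channel=command_channel) then polls this
    # channel during the run, as shown in the hunk above.
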
@@ -2,7 +2,7 @@ import logging
 import time
 from collections.abc import Callable, Generator
 from contextlib import contextmanager
-from typing import Any, Optional, Union
+from typing import Optional, Union

 from sqlalchemy.orm import Session

@@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
     WorkflowAppGenerateEntity,
 )
 from core.app.entities.queue_entities import (
+    AppQueueEvent,
     MessageQueueMessage,
     QueueAgentLogEvent,
     QueueErrorEvent,
@@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
     QueueLoopStartEvent,
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
-    QueueNodeInIterationFailedEvent,
-    QueueNodeInLoopFailedEvent,
     QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
@@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphRuntimeState, WorkflowExecution
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:

     def _handle_node_failed_events(
         self,
-        event: Union[
-            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
-        ],
+        event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
@@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
         if node_failed_response:
             yield node_failed_response
-
-    def _handle_parallel_branch_started_event(
-        self, event: QueueParallelBranchRunStartedEvent, **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch started events."""
-        self._ensure_workflow_initialized()
-
-        parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_start_resp
-
-    def _handle_parallel_branch_finished_events(
-        self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
-    ) -> Generator[StreamResponse, None, None]:
-        """Handle parallel branch finished events."""
-        self._ensure_workflow_initialized()
-
-        parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution_id=self._workflow_run_id,
-            event=event,
-        )
-        yield parallel_finish_resp

     def _handle_iteration_start_event(
         self, event: QueueIterationStartEvent, **kwargs
     ) -> Generator[StreamResponse, None, None]:
@@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
             QueueNodeRetryEvent: self._handle_node_retry_event,
             QueueNodeStartedEvent: self._handle_node_started_event,
             QueueNodeSucceededEvent: self._handle_node_succeeded_event,
-            # Parallel branch events
-            QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
             # Iteration events
             QueueIterationStartEvent: self._handle_iteration_start_event,
             QueueIterationNextEvent: self._handle_iteration_next_event,
@@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:

     def _dispatch_event(
         self,
-        event: Any,
+        event: AppQueueEvent,
         *,
         graph_runtime_state: Optional[GraphRuntimeState] = None,
         tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
@@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
             event,
             (
                 QueueNodeFailedEvent,
-                QueueNodeInIterationFailedEvent,
-                QueueNodeInLoopFailedEvent,
                 QueueNodeExceptionEvent,
             ),
         ):
@@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
             )
             return

-        # Handle parallel branch finished events with isinstance check
-        if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
-            yield from self._handle_parallel_branch_finished_events(
-                event,
-                graph_runtime_state=graph_runtime_state,
-                tts_publisher=tts_publisher,
-                trace_manager=trace_manager,
-                queue_message=queue_message,
-            )
-            return
-
         # Handle workflow failed and stop events with isinstance check
         if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
             yield from self._handle_workflow_failed_and_stop_events(
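
_dispatch_event is now typed against AppQueueEvent instead of Any, so every handler is checked against the shared queue-event base. A reduced, self-contained illustration of what the tighter annotation buys (stub classes, not the real pydantic models):

    class AppQueueEventStub:
        """Stand-in for core.app.entities.queue_entities.AppQueueEvent."""


    class QueuePingEventStub(AppQueueEventStub):
        pass


    def dispatch(event: AppQueueEventStub) -> str:
        # With `event: Any` this signature accepted anything; type checkers
        # now flag callers that pass a non-queue-event.
        return type(event).__name__


    print(dispatch(QueuePingEventStub()))  # -> QueuePingEventStub
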
|
|||||||
@ -2,6 +2,7 @@ from collections.abc import Mapping
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueLoopStartEvent,
|
QueueLoopStartEvent,
|
||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
|
||||||
QueueNodeInLoopFailedEvent,
|
|
||||||
QueueNodeRetryEvent,
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
|
||||||
QueueParallelBranchRunStartedEvent,
|
|
||||||
QueueParallelBranchRunSucceededEvent,
|
|
||||||
QueueRetrieverResourcesEvent,
|
QueueRetrieverResourcesEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_events import (
|
||||||
AgentLogEvent,
|
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
IterationRunFailedEvent,
|
NodeRunAgentLogEvent,
|
||||||
IterationRunNextEvent,
|
|
||||||
IterationRunStartedEvent,
|
|
||||||
IterationRunSucceededEvent,
|
|
||||||
LoopRunFailedEvent,
|
|
||||||
LoopRunNextEvent,
|
|
||||||
LoopRunStartedEvent,
|
|
||||||
LoopRunSucceededEvent,
|
|
||||||
NodeInIterationFailedEvent,
|
|
||||||
NodeInLoopFailedEvent,
|
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
|
NodeRunIterationFailedEvent,
|
||||||
|
NodeRunIterationNextEvent,
|
||||||
|
NodeRunIterationStartedEvent,
|
||||||
|
NodeRunIterationSucceededEvent,
|
||||||
|
NodeRunLoopFailedEvent,
|
||||||
|
NodeRunLoopNextEvent,
|
||||||
|
NodeRunLoopStartedEvent,
|
||||||
|
NodeRunLoopSucceededEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunRetryEvent,
|
NodeRunRetryEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
ParallelBranchRunFailedEvent,
|
|
||||||
ParallelBranchRunStartedEvent,
|
|
||||||
ParallelBranchRunSucceededEvent,
|
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from models.enums import UserFrom
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
         self._variable_loader = variable_loader
         self._app_id = app_id
 
-    def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
+    def _init_graph(
+        self,
+        graph_config: Mapping[str, Any],
+        graph_runtime_state: GraphRuntimeState,
+        workflow_id: str = "",
+        tenant_id: str = "",
+        user_id: str = "",
+    ) -> Graph:
         """
         Init graph
         """
@@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
 
         if not isinstance(graph_config.get("edges"), list):
             raise ValueError("edges in workflow graph must be a list")
 
+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=tenant_id or "",
+            app_id=self._app_id,
+            workflow_id=workflow_id,
+            graph_config=graph_config,
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        # Use the provided graph_runtime_state for consistent state management
+
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
         # init graph
-        graph = Graph.init(graph_config=graph_config)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
 
         if not graph:
             raise ValueError("graph not found in workflow")
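
For orientation, here is a minimal sketch of how a caller could drive the refactored _init_graph. Everything outside the keyword arguments — runner, workflow, user_id, and the prepared variable_pool — is a hypothetical placeholder, not code from this commit.

# Sketch only: `runner`, `workflow`, `user_id`, and `variable_pool` are
# hypothetical placeholders; the keyword arguments follow the new signature.
import time

from core.workflow.entities import GraphRuntimeState

# One shared runtime state is created up front ...
graph_runtime_state = GraphRuntimeState(
    variable_pool=variable_pool,  # assumed to be prepared by the caller
    start_at=time.perf_counter(),
)

# ... so every node built by DifyNodeFactory inside _init_graph observes
# the same variable pool and counters.
graph = runner._init_graph(
    graph_config=workflow.graph_dict,
    graph_runtime_state=graph_runtime_state,
    workflow_id=workflow.id,
    tenant_id=workflow.tenant_id,
    user_id=user_id,
)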
@@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
         workflow: Workflow,
         node_id: str,
         user_inputs: dict,
+        graph_runtime_state: GraphRuntimeState,
     ) -> tuple[Graph, VariablePool]:
         """
         Get variable pool of single iteration
@@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
 
         graph_config["edges"] = edge_configs
 
+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            workflow_id=workflow.id,
+            graph_config=graph_config,
+            user_id="",
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
         # init graph
-        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
 
         if not graph:
             raise ValueError("graph not found in workflow")
@@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
         workflow: Workflow,
         node_id: str,
         user_inputs: dict,
+        graph_runtime_state: GraphRuntimeState,
     ) -> tuple[Graph, VariablePool]:
         """
         Get variable pool of single loop
@@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
 
         graph_config["edges"] = edge_configs
 
+        # Create required parameters for Graph.init
+        graph_init_params = GraphInitParams(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            workflow_id=workflow.id,
+            graph_config=graph_config,
+            user_id="",
+            user_from=UserFrom.ACCOUNT.value,
+            invoke_from=InvokeFrom.SERVICE_API.value,
+            call_depth=0,
+        )
+
+        node_factory = DifyNodeFactory(
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
         # init graph
-        graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
 
         if not graph:
             raise ValueError("graph not found in workflow")
@@ -310,29 +366,21 @@ class WorkflowBasedAppRunner:
             )
         elif isinstance(event, GraphRunFailedEvent):
             self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
+        elif isinstance(event, GraphRunAbortedEvent):
+            self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
         elif isinstance(event, NodeRunRetryEvent):
-            node_run_result = event.route_node_state.node_run_result
-            inputs: Mapping[str, Any] | None = {}
-            process_data: Mapping[str, Any] | None = {}
-            outputs: Mapping[str, Any] | None = {}
-            execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
-            if node_run_result:
-                inputs = node_run_result.inputs
-                process_data = node_run_result.process_data
-                outputs = node_run_result.outputs
-                execution_metadata = node_run_result.metadata
+            node_run_result = event.node_run_result
+            inputs = node_run_result.inputs
+            process_data = node_run_result.process_data
+            outputs = node_run_result.outputs
+            execution_metadata = node_run_result.metadata
             self._publish_event(
                 QueueNodeRetryEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
+                    node_title=event.node_title,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                     start_at=event.start_at,
-                    node_run_index=event.route_node_state.index,
                     predecessor_node_id=event.predecessor_node_id,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
@@ -343,6 +391,8 @@ class WorkflowBasedAppRunner:
                     error=event.error,
                     execution_metadata=execution_metadata,
                     retry_index=event.retry_index,
+                    provider_type=event.provider_type,
+                    provider_id=event.provider_id,
                 )
             )
         elif isinstance(event, NodeRunStartedEvent):
@@ -350,44 +400,30 @@ class WorkflowBasedAppRunner:
                 QueueNodeStartedEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
+                    node_title=event.node_title,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    node_run_index=event.route_node_state.index,
+                    start_at=event.start_at,
                     predecessor_node_id=event.predecessor_node_id,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                     parallel_mode_run_id=event.parallel_mode_run_id,
                     agent_strategy=event.agent_strategy,
+                    provider_type=event.provider_type,
+                    provider_id=event.provider_id,
                 )
             )
         elif isinstance(event, NodeRunSucceededEvent):
-            node_run_result = event.route_node_state.node_run_result
-            if node_run_result:
-                inputs = node_run_result.inputs
-                process_data = node_run_result.process_data
-                outputs = node_run_result.outputs
-                execution_metadata = node_run_result.metadata
-            else:
-                inputs = {}
-                process_data = {}
-                outputs = {}
-                execution_metadata = {}
+            node_run_result = event.node_run_result
+            inputs = node_run_result.inputs
+            process_data = node_run_result.process_data
+            outputs = node_run_result.outputs
+            execution_metadata = node_run_result.metadata
             self._publish_event(
                 QueueNodeSucceededEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
+                    start_at=event.start_at,
                     inputs=inputs,
                     process_data=process_data,
                     outputs=outputs,
@@ -396,34 +432,18 @@ class WorkflowBasedAppRunner:
                     in_loop_id=event.in_loop_id,
                 )
             )
 
         elif isinstance(event, NodeRunFailedEvent):
             self._publish_event(
                 QueueNodeFailedEvent(
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    error=event.route_node_state.node_run_result.error
-                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
-                    else "Unknown error",
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    start_at=event.start_at,
+                    inputs=event.node_run_result.inputs,
+                    process_data=event.node_run_result.process_data,
+                    outputs=event.node_run_result.outputs,
+                    error=event.node_run_result.error or "Unknown error",
+                    execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
@@ -434,93 +454,21 @@ class WorkflowBasedAppRunner:
                     node_execution_id=event.id,
                     node_id=event.node_id,
                     node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    error=event.route_node_state.node_run_result.error
-                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
-                    else "Unknown error",
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
+                    start_at=event.start_at,
+                    inputs=event.node_run_result.inputs,
+                    process_data=event.node_run_result.process_data,
+                    outputs=event.node_run_result.outputs,
+                    error=event.node_run_result.error or "Unknown error",
+                    execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
             )
 
-        elif isinstance(event, NodeInIterationFailedEvent):
-            self._publish_event(
-                QueueNodeInIterationFailedEvent(
-                    node_execution_id=event.id,
-                    node_id=event.node_id,
-                    node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
-                    in_iteration_id=event.in_iteration_id,
-                    error=event.error,
-                )
-            )
-        elif isinstance(event, NodeInLoopFailedEvent):
-            self._publish_event(
-                QueueNodeInLoopFailedEvent(
-                    node_execution_id=event.id,
-                    node_id=event.node_id,
-                    node_type=event.node_type,
-                    node_data=event.node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    start_at=event.route_node_state.start_at,
-                    inputs=event.route_node_state.node_run_result.inputs
-                    if event.route_node_state.node_run_result
-                    else {},
-                    process_data=event.route_node_state.node_run_result.process_data
-                    if event.route_node_state.node_run_result
-                    else {},
-                    outputs=event.route_node_state.node_run_result.outputs or {}
-                    if event.route_node_state.node_run_result
-                    else {},
-                    execution_metadata=event.route_node_state.node_run_result.metadata
-                    if event.route_node_state.node_run_result
-                    else {},
-                    in_loop_id=event.in_loop_id,
-                    error=event.error,
-                )
-            )
         elif isinstance(event, NodeRunStreamChunkEvent):
             self._publish_event(
                 QueueTextChunkEvent(
-                    text=event.chunk_content,
-                    from_variable_selector=event.from_variable_selector,
+                    text=event.chunk,
+                    from_variable_selector=list(event.selector),
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
                 )
@@ -533,10 +481,10 @@ class WorkflowBasedAppRunner:
                     in_loop_id=event.in_loop_id,
                 )
             )
-        elif isinstance(event, AgentLogEvent):
+        elif isinstance(event, NodeRunAgentLogEvent):
             self._publish_event(
                 QueueAgentLogEvent(
-                    id=event.id,
+                    id=event.message_id,
                     label=event.label,
                     node_execution_id=event.node_execution_id,
                     parent_id=event.parent_id,
@@ -547,51 +495,13 @@ class WorkflowBasedAppRunner:
                     node_id=event.node_id,
                 )
             )
-        elif isinstance(event, ParallelBranchRunStartedEvent):
-            self._publish_event(
-                QueueParallelBranchRunStartedEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                )
-            )
-        elif isinstance(event, ParallelBranchRunSucceededEvent):
-            self._publish_event(
-                QueueParallelBranchRunSucceededEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                )
-            )
-        elif isinstance(event, ParallelBranchRunFailedEvent):
-            self._publish_event(
-                QueueParallelBranchRunFailedEvent(
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
-                    in_iteration_id=event.in_iteration_id,
-                    in_loop_id=event.in_loop_id,
-                    error=event.error,
-                )
-            )
-        elif isinstance(event, IterationRunStartedEvent):
+        elif isinstance(event, NodeRunIterationStartedEvent):
             self._publish_event(
                 QueueIterationStartEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
@@ -599,55 +509,41 @@ class WorkflowBasedAppRunner:
                     metadata=event.metadata,
                 )
             )
-        elif isinstance(event, IterationRunNextEvent):
+        elif isinstance(event, NodeRunIterationNextEvent):
             self._publish_event(
                 QueueIterationNextEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     index=event.index,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_iteration_output,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
-                    duration=event.duration,
                 )
             )
-        elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
+        elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
             self._publish_event(
                 QueueIterationCompletedEvent(
-                    node_execution_id=event.iteration_id,
-                    node_id=event.iteration_node_id,
-                    node_type=event.iteration_node_type,
-                    node_data=event.iteration_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
+                    error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
                 )
             )
-        elif isinstance(event, LoopRunStartedEvent):
+        elif isinstance(event, NodeRunLoopStartedEvent):
             self._publish_event(
                 QueueLoopStartEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
@@ -655,42 +551,32 @@ class WorkflowBasedAppRunner:
                     metadata=event.metadata,
                 )
             )
-        elif isinstance(event, LoopRunNextEvent):
+        elif isinstance(event, NodeRunLoopNextEvent):
             self._publish_event(
                 QueueLoopNextEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     index=event.index,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_loop_output,
-                    parallel_mode_run_id=event.parallel_mode_run_id,
-                    duration=event.duration,
                 )
             )
-        elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
+        elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
             self._publish_event(
                 QueueLoopCompletedEvent(
-                    node_execution_id=event.loop_id,
-                    node_id=event.loop_node_id,
-                    node_type=event.loop_node_type,
-                    node_data=event.loop_node_data,
-                    parallel_id=event.parallel_id,
-                    parallel_start_node_id=event.parallel_start_node_id,
-                    parent_parallel_id=event.parent_parallel_id,
-                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_title=event.node_title,
                     start_at=event.start_at,
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     inputs=event.inputs,
                     outputs=event.outputs,
                     metadata=event.metadata,
                     steps=event.steps,
-                    error=event.error if isinstance(event, LoopRunFailedEvent) else None,
+                    error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
                 )
             )
 
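The handler above is a pure translation layer: each graph-engine event is mapped to an app queue event with no business logic of its own. A condensed sketch of that dispatch shape follows; handler bodies are elided and the mapping shown is illustrative, not the full production table.

# Condensed, illustrative sketch of the event-translation pattern above.
from core.app.entities.queue_entities import QueueWorkflowFailedEvent
from core.workflow.graph_events import (
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
)


def translate(event):
    if isinstance(event, GraphRunFailedEvent):
        return QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)
    if isinstance(event, GraphRunAbortedEvent):
        # An abort carries no exception tally, so it is surfaced as a
        # failure with exceptions_count=0, exactly as in the diff above.
        return QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0)
    # ... one branch per remaining graph/node event, as in the diff ...
    return None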
@@ -7,11 +7,9 @@ from pydantic import BaseModel
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.base import BaseNodeData
 
 
 class QueueEvent(StrEnum):
@@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
     ANNOTATION_REPLY = "annotation_reply"
     AGENT_THOUGHT = "agent_thought"
     MESSAGE_FILE = "message_file"
-    PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
-    PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
-    PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
     AGENT_LOG = "agent_log"
     ERROR = "error"
     PING = "ping"
@@ -80,15 +75,7 @@ class QueueIterationStartEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
+    node_title: str
     start_at: datetime
 
     node_run_index: int
@@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    parallel_mode_run_id: Optional[str] = None
-    """iteratoin run in parallel mode run id"""
+    node_title: str
     node_run_index: int
     output: Optional[Any] = None  # output for the current iteration
-    duration: Optional[float] = None
 
 
 class QueueIterationCompletedEvent(AppQueueEvent):
@@ -134,15 +110,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
+    node_title: str
     start_at: datetime
 
     node_run_index: int
@@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
+    node_title: str
     parallel_id: Optional[str] = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
@@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
+    node_title: str
     parallel_id: Optional[str] = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
@@ -204,7 +172,6 @@ class QueueLoopNextEvent(AppQueueEvent):
     """iteratoin run in parallel mode run id"""
     node_run_index: int
     output: Optional[Any] = None  # output for the current loop
-    duration: Optional[float] = None
 
 
 class QueueLoopCompletedEvent(AppQueueEvent):
@@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
+    node_title: str
     parallel_id: Optional[str] = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
@@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
 
     node_execution_id: str
     node_id: str
+    node_title: str
     node_type: NodeType
-    node_data: BaseNodeData
-    node_run_index: int = 1
+    node_run_index: int = 1  # FIXME(-LAN-): may not used
     predecessor_node_id: Optional[str] = None
     parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
     parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
     parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
     in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
     in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
     start_at: datetime
     parallel_mode_run_id: Optional[str] = None
-    """iteratoin run in parallel mode run id"""
     agent_strategy: Optional[AgentNodeStrategyInit] = None
+
+    # FIXME(-LAN-): only for ToolNode, need to refactor
+    provider_type: str  # should be a core.tools.entities.tool_entities.ToolProviderType
+    provider_id: str
 
 
 class QueueNodeSucceededEvent(AppQueueEvent):
     """
@@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: Optional[str] = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
@@ -417,10 +380,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
 
     error: Optional[str] = None
-    """single iteration duration map"""
-    iteration_duration_map: Optional[dict[str, float]] = None
-    """single loop duration map"""
-    loop_duration_map: Optional[dict[str, float]] = None
 
 
 class QueueAgentLogEvent(AppQueueEvent):
@@ -454,72 +413,6 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
     retry_index: int  # retry index
 
 
-class QueueNodeInIterationFailedEvent(AppQueueEvent):
-    """
-    QueueNodeInIterationFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.NODE_FAILED
-
-    node_execution_id: str
-    node_id: str
-    node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
-    start_at: datetime
-
-    inputs: Optional[Mapping[str, Any]] = None
-    process_data: Optional[Mapping[str, Any]] = None
-    outputs: Optional[Mapping[str, Any]] = None
-    execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
-
-    error: str
-
-
-class QueueNodeInLoopFailedEvent(AppQueueEvent):
-    """
-    QueueNodeInLoopFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.NODE_FAILED
-
-    node_execution_id: str
-    node_id: str
-    node_type: NodeType
-    node_data: BaseNodeData
-    parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
-    start_at: datetime
-
-    inputs: Optional[Mapping[str, Any]] = None
-    process_data: Optional[Mapping[str, Any]] = None
-    outputs: Optional[Mapping[str, Any]] = None
-    execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
-
-    error: str
-
-
 class QueueNodeExceptionEvent(AppQueueEvent):
     """
     QueueNodeExceptionEvent entity
@@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: Optional[str] = None
     """parallel id if node is in parallel"""
     parallel_start_node_id: Optional[str] = None
@@ -563,15 +455,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
     node_execution_id: str
     node_id: str
     node_type: NodeType
-    node_data: BaseNodeData
     parallel_id: Optional[str] = None
-    """parallel id if node is in parallel"""
-    parallel_start_node_id: Optional[str] = None
-    """parallel start node id if node is in parallel"""
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
     in_iteration_id: Optional[str] = None
     """iteration id if node is in iteration"""
     in_loop_id: Optional[str] = None
@@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
     """
 
     pass
-
-
-class QueueParallelBranchRunStartedEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunStartedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
-
-
-class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunSucceededEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
-
-
-class QueueParallelBranchRunFailedEvent(AppQueueEvent):
-    """
-    QueueParallelBranchRunFailedEvent entity
-    """
-
-    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
-
-    parallel_id: str
-    parallel_start_node_id: str
-    parent_parallel_id: Optional[str] = None
-    """parent parallel id if node is in parallel"""
-    parent_parallel_start_node_id: Optional[str] = None
-    """parent parallel start node id if node is in parallel"""
-    in_iteration_id: Optional[str] = None
-    """iteration id if node is in iteration"""
-    in_loop_id: Optional[str] = None
-    """loop id if node is in loop"""
-    error: str
 
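Taken together, the queue-entity changes replace the heavyweight node_data: BaseNodeData payload with a plain node_title: str and drop the per-branch parallel bookkeeping. A reduced, illustrative model of the new event shape (not a class from the codebase):

# Illustrative only: a reduced pydantic model showing the slimmer event
# shape after this change. The real classes carry more fields.
from datetime import datetime
from typing import Any, Optional

from pydantic import BaseModel


class SlimQueueNodeEvent(BaseModel):  # hypothetical name
    node_execution_id: str
    node_id: str
    node_title: str  # was: node_data: BaseNodeData
    start_at: datetime
    inputs: Optional[dict[str, Any]] = None
    # parallel_id / parallel_start_node_id / parent_parallel_id /
    # parent_parallel_start_node_id no longer exist on these events.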
@@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.workflow.entities.node_entities import AgentNodeStrategyInit
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.entities import AgentNodeStrategyInit
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
 
 class AnnotationReplyAccount(BaseModel):
@@ -71,8 +71,6 @@ class StreamEvent(Enum):
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
     NODE_RETRY = "node_retry"
-    PARALLEL_BRANCH_STARTED = "parallel_branch_started"
-    PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
     ITERATION_NEXT = "iteration_next"
     ITERATION_COMPLETED = "iteration_completed"
@@ -440,54 +438,6 @@ class NodeRetryStreamResponse(StreamResponse):
         }
 
 
-class ParallelBranchStartStreamResponse(StreamResponse):
-    """
-    ParallelBranchStartStreamResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        parallel_id: str
-        parallel_branch_id: str
-        parent_parallel_id: Optional[str] = None
-        parent_parallel_start_node_id: Optional[str] = None
-        iteration_id: Optional[str] = None
-        loop_id: Optional[str] = None
-        created_at: int
-
-    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
-    workflow_run_id: str
-    data: Data
-
-
-class ParallelBranchFinishedStreamResponse(StreamResponse):
-    """
-    ParallelBranchFinishedStreamResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        parallel_id: str
-        parallel_branch_id: str
-        parent_parallel_id: Optional[str] = None
-        parent_parallel_start_node_id: Optional[str] = None
-        iteration_id: Optional[str] = None
-        loop_id: Optional[str] = None
-        status: str
-        error: Optional[str] = None
-        created_at: int
-
-    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
-    workflow_run_id: str
-    data: Data
-
-
 class IterationNodeStartStreamResponse(StreamResponse):
     """
     NodeStartStreamResponse entity
@@ -506,8 +456,6 @@ class IterationNodeStartStreamResponse(StreamResponse):
         extras: dict = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
-        parallel_id: Optional[str] = None
-        parallel_start_node_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_STARTED
     workflow_run_id: str
@@ -530,12 +478,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         title: str
         index: int
         created_at: int
-        pre_iteration_output: Optional[Any] = None
         extras: dict = Field(default_factory=dict)
-        parallel_id: Optional[str] = None
-        parallel_start_node_id: Optional[str] = None
-        parallel_mode_run_id: Optional[str] = None
-        duration: Optional[float] = None
 
     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
@@ -567,8 +510,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         execution_metadata: Optional[Mapping] = None
         finished_at: int
         steps: int
-        parallel_id: Optional[str] = None
-        parallel_start_node_id: Optional[str] = None
 
     event: StreamEvent = StreamEvent.ITERATION_COMPLETED
     workflow_run_id: str
@@ -622,7 +563,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
         parallel_mode_run_id: Optional[str] = None
-        duration: Optional[float] = None
 
     event: StreamEvent = StreamEvent.LOOP_NEXT
     workflow_run_id: str
 
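Clients consuming the SSE stream see the same trimming; for example, an "iteration_next" frame now carries only the fields kept above. An illustrative payload follows — the values are made up and the field set is inferred from the Data model in the diff, so it is not exhaustive.

# Illustrative "iteration_next" frame after this change; values are
# hypothetical and the field set is inferred, not exhaustive.
payload = {
    "event": "iteration_next",
    "workflow_run_id": "run-123",
    "data": {
        "title": "Iteration",
        "index": 1,
        "created_at": 1700000000,
        "extras": {},
        # pre_iteration_output, parallel_id, parallel_start_node_id,
        # parallel_mode_run_id, and duration were removed.
    },
}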
@@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.plugin.entities.plugin import ModelProviderID
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models.provider import (
@@ -41,6 +40,7 @@ from models.provider import (
     ProviderType,
     TenantPreferredModelProvider,
 )
+from models.provider_ids import ModelProviderID
 
 logger = logging.getLogger(__name__)
 
@@ -627,6 +627,7 @@ class ProviderConfiguration(BaseModel):
         Get custom model credentials.
         """
         # get provider model
+
         model_provider_id = ModelProviderID(self.provider.provider)
         provider_names = [self.provider.provider]
         if model_provider_id.is_langgenius():
@@ -1124,6 +1125,7 @@ class ProviderConfiguration(BaseModel):
         """
         Get provider model setting.
        """
+
         model_provider_id = ModelProviderID(self.provider.provider)
         provider_names = [self.provider.provider]
         if model_provider_id.is_langgenius():
@@ -1207,6 +1209,7 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+
         model_provider_id = ModelProviderID(self.provider.provider)
         provider_names = [self.provider.provider]
         if model_provider_id.is_langgenius():
 
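The same three-line lookup recurs in each ProviderConfiguration method touched above because langgenius-maintained providers may be stored under either the fully qualified plugin id or a shorter name. A hedged sketch of that pattern — the stored value is made up, and the exact ModelProviderID API (including the provider_name attribute) is assumed from its usage here rather than confirmed:

# Sketch of the recurring provider-name fallback; the stored provider value
# is hypothetical and ModelProviderID semantics are assumed from the diff.
from models.provider_ids import ModelProviderID

stored_provider = "langgenius/openai/openai"  # hypothetical stored value
model_provider_id = ModelProviderID(stored_provider)
provider_names = [stored_provider]
if model_provider_id.is_langgenius():
    # First-party providers may also be persisted under a short name,
    # so database queries match against both candidates.
    provider_names.append(model_provider_id.provider_name)  # assumed attribute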
@@ -12,8 +12,8 @@ def obfuscated_token(token: str):
 
 
 def encrypt_token(tenant_id: str, token: str):
+    from extensions.ext_database import db
     from models.account import Tenant
-    from models.engine import db
 
     if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
         raise ValueError(f"Tenant with id {tenant_id} not found")
 
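The only change here is where db comes from: extensions.ext_database owns the initialized SQLAlchemy instance, and keeping the import local to the function defers it to call time, which helps avoid import cycles between models and helpers. A sketch of the resulting shape, with the encryption tail elided because it lies outside the visible hunk:

# Sketch of the resulting function shape; the encryption step itself is
# elided here because it is outside the visible hunk.
def encrypt_token_sketch(tenant_id: str, token: str):
    from extensions.ext_database import db  # canonical db instance
    from models.account import Tenant

    if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
        raise ValueError(f"Tenant with id {tenant_id} not found")
    ...  # encrypt `token` for this tenant, as in the original function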
@@ -28,8 +28,9 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.graph_engine.entities.event import AgentLogEvent
-from models import App, Message, WorkflowNodeExecutionModel, db
+from core.workflow.node_events import AgentLogEvent
+from extensions.ext_database import db
+from models import App, Message, WorkflowNodeExecutionModel
 
 logger = logging.getLogger(__name__)
 
@@ -2,6 +2,7 @@ from collections.abc import Sequence
 from typing import Optional
 
 from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
@@ -39,86 +40,89 @@ class TokenBufferMemory:
         :param max_token_limit: max token limit
         :param message_limit: message limit
         """
-        app_record = self.conversation.app
-
-        # fetch limited messages, and return reversed
-        stmt = (
-            select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
-        )
-
-        if message_limit and message_limit > 0:
-            message_limit = min(message_limit, 500)
-        else:
-            message_limit = 500
-
-        stmt = stmt.limit(message_limit)
-
-        messages = db.session.scalars(stmt).all()
-
-        # instead of all messages from the conversation, we only need to extract messages
-        # that belong to the thread of last message
-        thread_messages = extract_thread_messages(messages)
-
-        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
-        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
-            thread_messages.pop(0)
-
-        messages = list(reversed(thread_messages))
-
-        prompt_messages: list[PromptMessage] = []
-        for message in messages:
-            files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
-            if files:
-                file_extra_config = None
-                if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
-                    file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
-                elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-                    workflow_run = db.session.scalar(
-                        select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
-                    )
-                    if not workflow_run:
-                        raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
-                    workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
-                    if not workflow:
-                        raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
-                    file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-                else:
-                    raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
-
-                detail = ImagePromptMessageContent.DETAIL.LOW
-                if file_extra_config and app_record:
-                    file_objs = file_factory.build_from_message_files(
-                        message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
-                    )
-                    if file_extra_config.image_config and file_extra_config.image_config.detail:
-                        detail = file_extra_config.image_config.detail
-                else:
-                    file_objs = []
-
-                if not file_objs:
-                    prompt_messages.append(UserPromptMessage(content=message.query))
-                else:
-                    prompt_message_contents: list[PromptMessageContentUnionTypes] = []
-                    for file in file_objs:
-                        prompt_message = file_manager.to_prompt_message_content(
-                            file,
-                            image_detail_config=detail,
-                        )
-                        prompt_message_contents.append(prompt_message)
-                    prompt_message_contents.append(TextPromptMessageContent(data=message.query))
-
-                    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-
-            else:
-                prompt_messages.append(UserPromptMessage(content=message.query))
-
-            prompt_messages.append(AssistantPromptMessage(content=message.answer))
-
-        if not prompt_messages:
-            return []
-
-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
+        with Session(db.engine) as session:
+            app_record = self.conversation.app
+
+            # fetch limited messages, and return reversed
+            stmt = (
+                select(Message)
+                .where(Message.conversation_id == self.conversation.id)
+                .order_by(Message.created_at.desc())
+            )
+
+            if message_limit and message_limit > 0:
+                message_limit = min(message_limit, 500)
+            else:
+                message_limit = 500
+
+            stmt = stmt.limit(message_limit)
+
+            messages = session.scalars(stmt).all()
+
+            # instead of all messages from the conversation, we only need to extract messages
+            # that belong to the thread of last message
+            thread_messages = extract_thread_messages(messages)
+
+            # for newly created message, its answer is temporarily empty, we don't need to add it to memory
+            if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
+                thread_messages.pop(0)
+
+            messages = list(reversed(thread_messages))
+
+            prompt_messages: list[PromptMessage] = []
|
||||||
|
for message in messages:
|
||||||
|
files = session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||||
|
if files:
|
||||||
|
file_extra_config = None
|
||||||
|
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||||
|
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||||
|
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
|
workflow_run = session.scalar(
|
||||||
|
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||||
|
)
|
||||||
|
if not workflow_run:
|
||||||
|
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||||
|
workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||||
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||||
|
|
||||||
|
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
if file_extra_config and app_record:
|
||||||
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
|
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||||
|
detail = file_extra_config.image_config.detail
|
||||||
|
else:
|
||||||
|
file_objs = []
|
||||||
|
|
||||||
|
if not file_objs:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
else:
|
||||||
|
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||||
|
for file in file_objs:
|
||||||
|
prompt_message = file_manager.to_prompt_message_content(
|
||||||
|
file,
|
||||||
|
image_detail_config=detail,
|
||||||
|
)
|
||||||
|
prompt_message_contents.append(prompt_message)
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
|
|
||||||
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
|
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
|
||||||
|
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||||
|
|
||||||
|
if not prompt_messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# prune the chat message if it exceeds the max token limit
|
||||||
|
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
if curr_message_tokens > max_token_limit:
|
if curr_message_tokens > max_token_limit:
|
||||||
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||||
|
|||||||
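Note: the TokenBufferMemory hunk above scopes every read to one explicit
with Session(db.engine) block and swaps db.session for that session. A
minimal runnable sketch of the pattern, assuming SQLAlchemy 2.x (the Message
model here is illustrative, not the project's):

    from sqlalchemy import Integer, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Message(Base):
        __tablename__ = "messages"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        conversation_id: Mapped[str] = mapped_column(String(36))

    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    # one explicit session scopes all reads and is closed automatically on exit
    with Session(engine) as session:
        stmt = (
            select(Message)
            .where(Message.conversation_id == "c1")
            .order_by(Message.id.desc())
            .limit(500)
        )
        rows = session.scalars(stmt).all()
    print(rows)  # [] on the empty in-memory database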
@@ -24,8 +24,7 @@ from core.model_runtime.errors.invoke import (
     InvokeRateLimitError,
     InvokeServerUnavailableError,
 )
-from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
-from core.plugin.impl.model import PluginModelClient
+from core.plugin.entities.plugin_daemon import PluginModelProviderEntity


 class AIModel(BaseModel):
@@ -53,6 +52,8 @@ class AIModel(BaseModel):

         :return: Invoke error mapping
         """
+        from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
+
         return {
             InvokeConnectionError: [InvokeConnectionError],
             InvokeServerUnavailableError: [InvokeServerUnavailableError],
@@ -140,6 +141,8 @@ class AIModel(BaseModel):
         :param credentials: model credentials
         :return: model schema
         """
+        from core.plugin.impl.model import PluginModelClient
+
         plugin_model_manager = PluginModelClient()
         cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
         # sort credentials
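Note: the three hunks above share one pattern: the PluginModelClient and
PluginDaemonInnerError imports move from module level into the methods that
use them, so importing the module no longer drags in the plugin client at
import time (the usual way to break an import cycle). A minimal runnable
sketch with a standard-library stand-in:

    class AIModelSketch:
        def get_client(self):
            # deferred import: executed on first call, then served from
            # sys.modules, so repeat calls cost almost nothing
            from json import JSONEncoder  # stand-in for PluginModelClient

            return JSONEncoder()

    print(AIModelSketch().get_client().encode({"ok": True}))  # {"ok": true}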
@@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import (
     PriceType,
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.impl.model import PluginModelClient

 logger = logging.getLogger(__name__)

@@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel):
         result: Union[LLMResult, Generator[LLMResultChunk, None, None]]

         try:
+            from core.plugin.impl.model import PluginModelClient
+
             plugin_model_manager = PluginModelClient()
             result = plugin_model_manager.invoke_llm(
                 tenant_id=self.tenant_id,
@@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel):
         :return:
         """
         if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
+            from core.plugin.impl.model import PluginModelClient
+
             plugin_model_manager = PluginModelClient()
             return plugin_model_manager.get_llm_num_tokens(
                 tenant_id=self.tenant_id,
@@ -5,7 +5,6 @@ from pydantic import ConfigDict

 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.impl.model import PluginModelClient


 class ModerationModel(AIModel):
@@ -31,6 +30,8 @@ class ModerationModel(AIModel):
         self.started_at = time.perf_counter()

         try:
+            from core.plugin.impl.model import PluginModelClient
+
             plugin_model_manager = PluginModelClient()
             return plugin_model_manager.invoke_moderation(
                 tenant_id=self.tenant_id,
|||||||
@ -3,7 +3,6 @@ from typing import Optional
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
from core.plugin.impl.model import PluginModelClient
|
|
||||||
|
|
||||||
|
|
||||||
class RerankModel(AIModel):
|
class RerankModel(AIModel):
|
||||||
@ -36,6 +35,8 @@ class RerankModel(AIModel):
|
|||||||
:return: rerank result
|
:return: rerank result
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
from core.plugin.impl.model import PluginModelClient
|
||||||
|
|
||||||
plugin_model_manager = PluginModelClient()
|
plugin_model_manager = PluginModelClient()
|
||||||
return plugin_model_manager.invoke_rerank(
|
return plugin_model_manager.invoke_rerank(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from pydantic import ConfigDict

 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.impl.model import PluginModelClient


 class Speech2TextModel(AIModel):
@@ -28,6 +27,8 @@ class Speech2TextModel(AIModel):
         :return: text for given audio file
         """
         try:
+            from core.plugin.impl.model import PluginModelClient
+
             plugin_model_manager = PluginModelClient()
             return plugin_model_manager.invoke_speech_to_text(
                 tenant_id=self.tenant_id,
@@ -6,7 +6,6 @@ from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.impl.model import PluginModelClient


 class TextEmbeddingModel(AIModel):
@@ -37,6 +36,8 @@ class TextEmbeddingModel(AIModel):
         :param input_type: input type
         :return: embeddings result
         """
+        from core.plugin.impl.model import PluginModelClient
+
         try:
             plugin_model_manager = PluginModelClient()
             return plugin_model_manager.invoke_text_embedding(
@@ -61,6 +62,8 @@ class TextEmbeddingModel(AIModel):
         :param texts: texts to embed
         :return:
         """
+        from core.plugin.impl.model import PluginModelClient
+
         plugin_model_manager = PluginModelClient()
         return plugin_model_manager.get_text_embedding_num_tokens(
             tenant_id=self.tenant_id,
@@ -6,7 +6,6 @@ from pydantic import ConfigDict

 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.impl.model import PluginModelClient

 logger = logging.getLogger(__name__)

@@ -42,6 +41,8 @@ class TTSModel(AIModel):
         :return: translated audio file
         """
         try:
+            from core.plugin.impl.model import PluginModelClient
+
             plugin_model_manager = PluginModelClient()
             return plugin_model_manager.invoke_tts(
                 tenant_id=self.tenant_id,
@@ -65,6 +66,8 @@ class TTSModel(AIModel):
         :param credentials: The credentials required to access the TTS model.
         :return: A list of voices supported by the TTS model.
         """
+        from core.plugin.impl.model import PluginModelClient
+
         plugin_model_manager = PluginModelClient()
         return plugin_model_manager.get_tts_model_voices(
             tenant_id=self.tenant_id,
@@ -20,10 +20,8 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
 from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
-from core.plugin.entities.plugin import ModelProviderID
 from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
-from core.plugin.impl.asset import PluginAssetManager
-from core.plugin.impl.model import PluginModelClient
+from models.provider_ids import ModelProviderID

 logger = logging.getLogger(__name__)

@@ -37,6 +35,8 @@ class ModelProviderFactory:
     provider_position_map: dict[str, int]

     def __init__(self, tenant_id: str) -> None:
+        from core.plugin.impl.model import PluginModelClient
+
         self.provider_position_map = {}

         self.tenant_id = tenant_id
@@ -71,7 +71,7 @@ class ModelProviderFactory:

         return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]

-    def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
+    def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
         """
         Get all plugin model providers
         :return: list of plugin model providers
@@ -109,7 +109,7 @@ class ModelProviderFactory:
         plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
         return plugin_model_provider_entity.declaration

-    def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
+    def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
         """
         Get plugin model provider
         :param provider: provider name
@@ -366,6 +366,8 @@ class ModelProviderFactory:
         mime_type = image_mime_types.get(extension, "image/png")

         # get icon bytes from plugin asset manager
+        from core.plugin.impl.asset import PluginAssetManager
+
         plugin_asset_manager = PluginAssetManager()
         return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type

@@ -375,5 +377,6 @@ class ModelProviderFactory:
         :param provider: provider name
         :return: plugin id and provider name
         """
+
         provider_id = ModelProviderID(provider)
         return provider_id.plugin_id, provider_id.provider_name
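Note: the factory hunks above quote the PluginModelProviderEntity return
annotations. A quoted (forward-reference) annotation is evaluated only by type
checkers, so it pairs naturally with imports that are deferred or guarded. A
minimal runnable sketch (Decimal stands in for the entity type):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # imported for the type checker only; no runtime dependency
        from decimal import Decimal

    def make(value: str) -> "Decimal":
        # the quoted annotation is resolved lazily, so the runtime import
        # happens inside the function instead
        from decimal import Decimal

        return Decimal(value)

    print(make("1.5"))  # 1.5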
@@ -54,13 +54,10 @@ from core.ops.entities.trace_entity import (
 )
 from core.rag.models.document import Document
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from core.workflow.nodes import NodeType
-from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
+from core.workflow.entities import WorkflowNodeExecution
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from extensions.ext_database import db
+from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)

@@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
 )
 from core.ops.utils import filter_none_values
 from core.repositories import DifyCoreRepositoryFactory
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType
 from extensions.ext_database import db
 from models import EndUser, WorkflowNodeExecutionTriggeredFrom
 from models.enums import MessageStatus
@@ -28,8 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
 )
 from core.ops.utils import filter_none_values, generate_dotted_order
 from core.repositories import DifyCoreRepositoryFactory
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
 from extensions.ext_database import db
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

@@ -22,8 +22,7 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.repositories import DifyCoreRepositoryFactory
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
 from extensions.ext_database import db
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

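Note: the trace hunks above collapse imports that used to span
core.workflow.entities.workflow_node_execution and core.workflow.nodes.enums
into the single core.workflow.enums module. A minimal sketch of that kind of
consolidation, assuming Python 3.11+ for enum.StrEnum (the members shown are
illustrative):

    import enum

    class NodeType(enum.StrEnum):
        LLM = "llm"
        CODE = "code"

    class WorkflowNodeExecutionStatus(enum.StrEnum):
        RUNNING = "running"
        SUCCEEDED = "succeeded"

    # callers then need only one import line, e.g.:
    # from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
    print(NodeType.LLM, WorkflowNodeExecutionStatus.SUCCEEDED)  # llm succeeded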
@@ -5,7 +5,7 @@ import queue
 import threading
 import time
 from datetime import timedelta
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import UUID, uuid4

 from cachetools import LRUCache
@@ -30,13 +30,15 @@ from core.ops.entities.trace_entity import (
     WorkflowTraceInfo,
 )
 from core.ops.utils import get_message_data
-from core.workflow.entities.workflow_execution import WorkflowExecution
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
 from tasks.ops_trace_task import process_trace_tasks

+if TYPE_CHECKING:
+    from core.workflow.entities import WorkflowExecution
+
 logger = logging.getLogger(__name__)


@@ -410,7 +412,7 @@ class TraceTask:
         self,
         trace_type: Any,
         message_id: Optional[str] = None,
-        workflow_execution: Optional[WorkflowExecution] = None,
+        workflow_execution: Optional["WorkflowExecution"] = None,
         conversation_id: Optional[str] = None,
         user_id: Optional[str] = None,
         timer: Optional[Any] = None,
@@ -23,8 +23,7 @@ from core.ops.entities.trace_entity import (
 )
 from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
 from core.repositories import DifyCoreRepositoryFactory
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
 from extensions.ext_database import db
 from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom

@@ -164,7 +164,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
             invoke_from=InvokeFrom.SERVICE_API,
             streaming=stream,
             call_depth=1,
-            workflow_thread_pool_id=None,
         )

     @classmethod
|||||||
@ -1,5 +1,5 @@
|
|||||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.nodes.parameter_extractor.entities import (
|
from core.workflow.nodes.parameter_extractor.entities import (
|
||||||
ModelConfig as ParameterExtractorModelConfig,
|
ModelConfig as ParameterExtractorModelConfig,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
import enum
|
import enum
|
||||||
|
import json
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.entities.parameter_entities import CommonParameterType
|
from core.entities.parameter_entities import CommonParameterType
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.workflow.nodes.base.entities import NumberType
|
|
||||||
|
|
||||||
|
|
||||||
class PluginParameterOption(BaseModel):
|
class PluginParameterOption(BaseModel):
|
||||||
@@ -154,7 +154,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
                 raise ValueError("The tools selector must be a list.")
             return value
         case PluginParameterType.ANY:
-            if value and not isinstance(value, str | dict | list | NumberType):
+            if value and not isinstance(value, str | dict | list | int | float):
                 raise ValueError("The var selector must be a string, dictionary, list or number.")
             return value
         case PluginParameterType.ARRAY:
@@ -162,8 +162,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
             # Try to parse JSON string for arrays
             if isinstance(value, str):
                 try:
-                    import json
-
                     parsed_value = json.loads(value)
                     if isinstance(parsed_value, list):
                         return parsed_value
@@ -176,8 +174,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
             # Try to parse JSON string for objects
             if isinstance(value, str):
                 try:
-                    import json
-
                     parsed_value = json.loads(value)
                     if isinstance(parsed_value, dict):
                         return parsed_value
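Note: the cast_parameter_value hunks above replace the NumberType alias with a
plain int | float union and hoist import json to module level. Both work as
shown on Python 3.10+, where isinstance accepts PEP 604 unions directly. A
minimal runnable sketch:

    import json
    from typing import Any

    def cast_any(value: Any):
        # PEP 604 union works directly with isinstance on Python 3.10+
        if value and not isinstance(value, str | dict | list | int | float):
            raise ValueError("The var selector must be a string, dictionary, list or number.")
        return value

    def cast_array(value: Any):
        # json imported once at module level, as the hunks above now do
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
                if isinstance(parsed, list):
                    return parsed
            except json.JSONDecodeError:
                pass
        return value

    print(cast_any(3.5), cast_array("[1, 2]"))  # 3.5 [1, 2]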
@@ -1,11 +1,9 @@
 import datetime
 import enum
-import re
 from collections.abc import Mapping
 from typing import Any, Optional

 from pydantic import BaseModel, Field, model_validator
-from werkzeug.exceptions import NotFound

 from core.agent.plugin_entities import AgentStrategyProviderEntity
 from core.datasource.entities.datasource_entities import DatasourceProviderEntity
@@ -141,60 +139,6 @@ class PluginEntity(PluginInstallation):
         return self


-class GenericProviderID:
-    organization: str
-    plugin_name: str
-    provider_name: str
-    is_hardcoded: bool
-
-    def to_string(self) -> str:
-        return str(self)
-
-    def __str__(self) -> str:
-        return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
-
-    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
-        if not value:
-            raise NotFound("plugin not found, please add plugin")
-        # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
-        if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
-            # check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
-            if re.match(r"^[a-z0-9_-]+$", value):
-                value = f"langgenius/{value}/{value}"
-            else:
-                raise ValueError(f"Invalid plugin id {value}")
-
-        self.organization, self.plugin_name, self.provider_name = value.split("/")
-        self.is_hardcoded = is_hardcoded
-
-    def is_langgenius(self) -> bool:
-        return self.organization == "langgenius"
-
-    @property
-    def plugin_id(self) -> str:
-        return f"{self.organization}/{self.plugin_name}"
-
-
-class ModelProviderID(GenericProviderID):
-    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
-        super().__init__(value, is_hardcoded)
-        if self.organization == "langgenius" and self.provider_name == "google":
-            self.plugin_name = "gemini"
-
-
-class ToolProviderID(GenericProviderID):
-    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
-        super().__init__(value, is_hardcoded)
-        if self.organization == "langgenius":
-            if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
-                self.plugin_name = f"{self.provider_name}_tool"
-
-
-class DatasourceProviderID(GenericProviderID):
-    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
-        super().__init__(value, is_hardcoded)
-
-
 class PluginDependency(BaseModel):
     class Type(enum.StrEnum):
         Github = PluginInstallationSource.Github.value
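Note: the hunk above deletes the provider-ID classes from the plugin entities
module; later hunks re-import them from models.provider_ids, so this is a
move, not a removal. The normalization the deleted __init__ performed is worth
keeping in mind; a runnable sketch of just that rule, extracted from the code
above:

    import re

    def normalize(value: str) -> tuple[str, str, str]:
        # "$organization/$plugin_name/$provider_name"; a bare name expands
        # into the langgenius namespace, mirroring GenericProviderID.__init__
        if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
            if re.match(r"^[a-z0-9_-]+$", value):
                value = f"langgenius/{value}/{value}"
            else:
                raise ValueError(f"Invalid plugin id {value}")
        organization, plugin_name, provider_name = value.split("/")
        return organization, plugin_name, provider_name

    print(normalize("google"))  # ('langgenius', 'google', 'google')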
@@ -2,13 +2,13 @@ from collections.abc import Generator
 from typing import Any, Optional

 from core.agent.entities import AgentInvokeMessage
-from core.plugin.entities.plugin import GenericProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginAgentProviderEntity,
 )
 from core.plugin.entities.request import PluginInvokeContext
 from core.plugin.impl.base import BasePluginClient
 from core.plugin.utils.chunk_merger import merge_blob_chunks
+from models.provider_ids import GenericProviderID


 class PluginAgentClient(BasePluginClient):
@@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import (
     OnlineDriveDownloadFileRequest,
     WebsiteCrawlMessage,
 )
-from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDatasourceProviderEntity,
 )
 from core.plugin.impl.base import BasePluginClient
 from core.schemas.resolver import resolve_dify_schema_refs
+from models.provider_ids import DatasourceProviderID, GenericProviderID
 from services.tools.tools_transform_service import ToolTransformService


@@ -1,9 +1,9 @@
 from collections.abc import Mapping
 from typing import Any

-from core.plugin.entities.plugin import GenericProviderID
 from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
 from core.plugin.impl.base import BasePluginClient
+from models.provider_ids import GenericProviderID


 class DynamicSelectClient(BasePluginClient):
@@ -2,7 +2,6 @@ from collections.abc import Sequence

 from core.plugin.entities.bundle import PluginBundleDependency
 from core.plugin.entities.plugin import (
-    GenericProviderID,
     MissingPluginDependency,
     PluginDeclaration,
     PluginEntity,
@@ -16,6 +15,7 @@ from core.plugin.entities.plugin_daemon import (
     PluginListResponse,
 )
 from core.plugin.impl.base import BasePluginClient
+from models.provider_ids import GenericProviderID


 class PluginInstaller(BasePluginClient):
@@ -3,7 +3,6 @@ from typing import Any, Optional

 from pydantic import BaseModel

-from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginToolProviderEntity,
@@ -12,6 +11,7 @@ from core.plugin.impl.base import BasePluginClient
 from core.plugin.utils.chunk_merger import merge_blob_chunks
 from core.schemas.resolver import resolve_dify_schema_refs
 from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
+from models.provider_ids import GenericProviderID, ToolProviderID


 class PluginToolManager(BasePluginClient):
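Note: every hunk in this stretch makes the same one-line substitution: the
provider-ID names now come from models.provider_ids rather than
core.plugin.entities.plugin. Within this codebase each consumer keeps only the
names it actually uses, e.g.:

    from models.provider_ids import DatasourceProviderID, GenericProviderID, ModelProviderID, ToolProviderID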
@@ -34,7 +34,6 @@ from core.model_runtime.entities.provider_entities import (
     ProviderEntity,
 )
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
-from core.plugin.entities.plugin import ModelProviderID
 from extensions import ext_hosting_provider
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -49,6 +48,7 @@ from models.provider import (
     TenantDefaultModel,
     TenantPreferredModelProvider,
 )
+from models.provider_ids import ModelProviderID
 from services.feature_service import FeatureService


|||||||
@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.model_manager import ModelInstance
|
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.splitter.fixed_text_splitter import (
|
from core.rag.splitter.fixed_text_splitter import (
|
||||||
@ -16,6 +15,9 @@ from core.rag.splitter.text_splitter import TextSplitter
|
|||||||
from models.dataset import Dataset, DatasetProcessRule
|
from models.dataset import Dataset, DatasetProcessRule
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.model_manager import ModelInstance
|
||||||
|
|
||||||
|
|
||||||
class BaseIndexProcessor(ABC):
|
class BaseIndexProcessor(ABC):
|
||||||
"""Interface for extract files."""
|
"""Interface for extract files."""
|
||||||
@ -61,7 +63,7 @@ class BaseIndexProcessor(ABC):
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
chunk_overlap: int,
|
chunk_overlap: int,
|
||||||
separator: str,
|
separator: str,
|
||||||
embedding_model_instance: Optional[ModelInstance],
|
embedding_model_instance: Optional["ModelInstance"],
|
||||||
) -> TextSplitter:
|
) -> TextSplitter:
|
||||||
"""
|
"""
|
||||||
Get the NodeParser object according to the processing rule.
|
Get the NodeParser object according to the processing rule.
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ from typing import Optional, Union
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker

-from core.workflow.entities.workflow_execution import (
-    WorkflowExecution,
-    WorkflowExecutionStatus,
-    WorkflowType,
-)
+from core.workflow.entities import WorkflowExecution
+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
@@ -203,5 +200,4 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
             session.commit()

             # Update the in-memory cache for faster subsequent lookups
-            logger.debug("Updating cache for execution_id: %s", db_model.id)
             self._execution_cache[db_model.id] = db_model
@@ -12,12 +12,8 @@ from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker

 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.workflow.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-    WorkflowNodeExecutionMetadataKey,
-    WorkflowNodeExecutionStatus,
-)
-from core.workflow.nodes.enums import NodeType
+from core.workflow.entities import WorkflowNodeExecution
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
@@ -215,7 +211,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
             # Update the in-memory cache for faster subsequent lookups
             # Only cache if we have a node_execution_id to use as the cache key
             if db_model.node_execution_id:
-                logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
                 self._node_execution_cache[db_model.node_execution_id] = db_model

     def get_db_models_by_workflow_run(
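Note: the repository hunks above drop per-save debug logging but keep the
write-through cache update. A minimal runnable sketch of that save-then-cache
pattern (the dict stands in for the real database commit and model):

    class RepositorySketch:
        def __init__(self) -> None:
            self._cache: dict[str, object] = {}

        def save(self, key: str, model: object) -> None:
            # ... the database commit would go here ...
            # update the in-memory cache for faster subsequent lookups
            self._cache[key] = model

        def get(self, key: str) -> object | None:
            return self._cache.get(key)

    repo = RepositorySketch()
    repo.save("exec-1", {"status": "succeeded"})
    print(repo.get("exec-1"))  # {'status': 'succeeded'}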
@@ -2,4 +2,4 @@

 from .resolver import resolve_dify_schema_refs

 __all__ = ["resolve_dify_schema_refs"]
@@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional

 class SchemaRegistry:
     """Schema registry manages JSON schemas with version support"""

     _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
     _lock: ClassVar[threading.Lock] = threading.Lock()

@@ -25,41 +25,41 @@ class SchemaRegistry:
         if cls._default_instance is None:
             current_dir = Path(__file__).parent
             schema_dir = current_dir / "builtin" / "schemas"

             registry = cls(str(schema_dir))
             registry.load_all_versions()

             cls._default_instance = registry

         return cls._default_instance

     def load_all_versions(self) -> None:
         """Scans the schema directory and loads all versions"""
         if not self.base_dir.exists():
             return

         for entry in self.base_dir.iterdir():
             if not entry.is_dir():
                 continue

             version = entry.name
             if not version.startswith("v"):
                 continue

             self._load_version_dir(version, entry)

     def _load_version_dir(self, version: str, version_dir: Path) -> None:
         """Loads all schemas in a version directory"""
         if not version_dir.exists():
             return

         if version not in self.versions:
             self.versions[version] = {}

         for entry in version_dir.iterdir():
             if entry.suffix != ".json":
                 continue

             schema_name = entry.stem
             self._load_schema(version, schema_name, entry)

@@ -68,10 +68,10 @@ class SchemaRegistry:
         try:
             with open(schema_path, encoding="utf-8") as f:
                 schema = json.load(f)

             # Store the schema
             self.versions[version][schema_name] = schema

             # Extract and store metadata
             uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
             metadata = {
@@ -81,26 +81,26 @@ class SchemaRegistry:
                 "deprecated": schema.get("deprecated", False),
             }
             self.metadata[uri] = metadata

         except (OSError, json.JSONDecodeError) as e:
             print(f"Warning: failed to load schema {version}/{schema_name}: {e}")


     def get_schema(self, uri: str) -> Optional[Any]:
         """Retrieves a schema by URI with version support"""
         version, schema_name = self._parse_uri(uri)
         if not version or not schema_name:
             return None

         version_schemas = self.versions.get(version)
         if not version_schemas:
             return None

         return version_schemas.get(schema_name)

     def _parse_uri(self, uri: str) -> tuple[str, str]:
         """Parses a schema URI to extract version and schema name"""
         from core.schemas.resolver import parse_dify_schema_uri

         return parse_dify_schema_uri(uri)

     def list_versions(self) -> list[str]:
@@ -112,19 +112,15 @@ class SchemaRegistry:
         version_schemas = self.versions.get(version)
         if not version_schemas:
             return []

         return sorted(version_schemas.keys())

     def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
         """Returns all schemas for a version in the API format"""
         version_schemas = self.versions.get(version, {})

         result = []
         for schema_name, schema in version_schemas.items():
-            result.append({
-                "name": schema_name,
-                "label": schema.get("title", schema_name),
-                "schema": schema
-            })
+            result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema})

         return result
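Note: the hunk above collapses the multi-line dict append into one line; the
API shape is unchanged. A runnable sketch of what get_all_schemas_for_version
builds, using a list comprehension for brevity:

    schemas = {"app": {"title": "App Schema"}}
    result = [
        {"name": name, "label": schema.get("title", name), "schema": schema}
        for name, schema in schemas.items()
    ]
    print(result)  # [{'name': 'app', 'label': 'App Schema', 'schema': {'title': 'App Schema'}}]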
@@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$

 class SchemaResolutionError(Exception):
     """Base exception for schema resolution errors"""
+
     pass


 class CircularReferenceError(SchemaResolutionError):
     """Raised when a circular reference is detected"""
+
     def __init__(self, ref_uri: str, ref_path: list[str]):
         self.ref_uri = ref_uri
         self.ref_path = ref_path
@@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError):

 class MaxDepthExceededError(SchemaResolutionError):
     """Raised when maximum resolution depth is exceeded"""
+
     def __init__(self, max_depth: int):
         self.max_depth = max_depth
         super().__init__(f"Maximum resolution depth ({max_depth}) exceeded")
@@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError):

 class SchemaNotFoundError(SchemaResolutionError):
     """Raised when a referenced schema cannot be found"""
+
     def __init__(self, ref_uri: str):
         self.ref_uri = ref_uri
         super().__init__(f"Schema not found: {ref_uri}")
@@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError):
 @dataclass
 class QueueItem:
     """Represents an item in the BFS queue"""
+
     current: Any
     parent: Optional[Any]
     key: Optional[Union[str, int]]
@@ -56,39 +61,39 @@ class QueueItem:

 class SchemaResolver:
     """Resolver for Dify schema references with caching and optimizations"""

     _cache: dict[str, SchemaDict] = {}
     _cache_lock = threading.Lock()

     def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
         """
         Initialize the schema resolver

         Args:
             registry: Schema registry to use (defaults to default registry)
             max_depth: Maximum depth for reference resolution
         """
         self.registry = registry or SchemaRegistry.default_registry()
         self.max_depth = max_depth

     @classmethod
     def clear_cache(cls) -> None:
         """Clear the global schema cache"""
         with cls._cache_lock:
             cls._cache.clear()

     def resolve(self, schema: SchemaType) -> SchemaType:
         """
         Resolve all $ref references in the schema

         Performance optimization: quickly checks for $ref presence before processing.

         Args:
             schema: Schema to resolve

         Returns:
             Resolved schema with all references expanded

         Raises:
             CircularReferenceError: If circular reference detected
             MaxDepthExceededError: If max depth exceeded
@ -96,44 +101,39 @@ class SchemaResolver:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(schema, (dict, list)):
|
if not isinstance(schema, (dict, list)):
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
# Fast path: if no Dify refs found, return original schema unchanged
|
# Fast path: if no Dify refs found, return original schema unchanged
|
||||||
# This avoids expensive deepcopy and BFS traversal for schemas without refs
|
# This avoids expensive deepcopy and BFS traversal for schemas without refs
|
||||||
if not _has_dify_refs(schema):
|
if not _has_dify_refs(schema):
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
# Slow path: schema contains refs, perform full resolution
|
# Slow path: schema contains refs, perform full resolution
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
result = copy.deepcopy(schema)
|
result = copy.deepcopy(schema)
|
||||||
|
|
||||||
# Initialize BFS queue
|
# Initialize BFS queue
|
||||||
queue = deque([QueueItem(
|
queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())])
|
||||||
current=result,
|
|
||||||
parent=None,
|
|
||||||
key=None,
|
|
||||||
depth=0,
|
|
||||||
ref_path=set()
|
|
||||||
)])
|
|
||||||
|
|
||||||
while queue:
|
while queue:
|
||||||
item = queue.popleft()
|
item = queue.popleft()
|
||||||
|
|
||||||
# Process the current item
|
# Process the current item
|
||||||
self._process_queue_item(queue, item)
|
self._process_queue_item(queue, item)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
|
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
|
||||||
"""Process a single queue item"""
|
"""Process a single queue item"""
|
||||||
if isinstance(item.current, dict):
|
if isinstance(item.current, dict):
|
||||||
self._process_dict(queue, item)
|
self._process_dict(queue, item)
|
||||||
elif isinstance(item.current, list):
|
elif isinstance(item.current, list):
|
||||||
self._process_list(queue, item)
|
self._process_list(queue, item)
|
||||||
|
|
||||||
def _process_dict(self, queue: deque, item: QueueItem) -> None:
|
def _process_dict(self, queue: deque, item: QueueItem) -> None:
|
||||||
"""Process a dictionary item"""
|
"""Process a dictionary item"""
|
||||||
ref_uri = item.current.get("$ref")
|
ref_uri = item.current.get("$ref")
|
||||||
|
|
||||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||||
# Handle $ref resolution
|
# Handle $ref resolution
|
||||||
self._resolve_ref(queue, item, ref_uri)
|
self._resolve_ref(queue, item, ref_uri)
|
||||||
@ -144,14 +144,10 @@ class SchemaResolver:
|
|||||||
next_depth = item.depth + 1
|
next_depth = item.depth + 1
|
||||||
if next_depth >= self.max_depth:
|
if next_depth >= self.max_depth:
|
||||||
raise MaxDepthExceededError(self.max_depth)
|
raise MaxDepthExceededError(self.max_depth)
|
||||||
queue.append(QueueItem(
|
queue.append(
|
||||||
current=value,
|
QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path)
|
||||||
parent=item.current,
|
)
|
||||||
key=key,
|
|
||||||
depth=next_depth,
|
|
||||||
ref_path=item.ref_path
|
|
||||||
))
|
|
||||||
|
|
||||||
def _process_list(self, queue: deque, item: QueueItem) -> None:
|
def _process_list(self, queue: deque, item: QueueItem) -> None:
|
||||||
"""Process a list item"""
|
"""Process a list item"""
|
||||||
for idx, value in enumerate(item.current):
|
for idx, value in enumerate(item.current):
|
||||||
@ -159,14 +155,10 @@ class SchemaResolver:
|
|||||||
next_depth = item.depth + 1
|
next_depth = item.depth + 1
|
||||||
if next_depth >= self.max_depth:
|
if next_depth >= self.max_depth:
|
||||||
raise MaxDepthExceededError(self.max_depth)
|
raise MaxDepthExceededError(self.max_depth)
|
||||||
queue.append(QueueItem(
|
queue.append(
|
||||||
current=value,
|
QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path)
|
||||||
parent=item.current,
|
)
|
||||||
key=idx,
|
|
||||||
depth=next_depth,
|
|
||||||
ref_path=item.ref_path
|
|
||||||
))
|
|
||||||
|
|
||||||
def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
|
def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
|
||||||
"""Resolve a $ref reference"""
|
"""Resolve a $ref reference"""
|
||||||
# Check for circular reference
|
# Check for circular reference
|
||||||
@ -175,82 +167,78 @@ class SchemaResolver:
|
|||||||
item.current["$circular_ref"] = True
|
item.current["$circular_ref"] = True
|
||||||
logger.warning("Circular reference detected: %s", ref_uri)
|
logger.warning("Circular reference detected: %s", ref_uri)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get resolved schema (from cache or registry)
|
# Get resolved schema (from cache or registry)
|
||||||
resolved_schema = self._get_resolved_schema(ref_uri)
|
resolved_schema = self._get_resolved_schema(ref_uri)
|
||||||
if not resolved_schema:
|
if not resolved_schema:
|
||||||
logger.warning("Schema not found: %s", ref_uri)
|
logger.warning("Schema not found: %s", ref_uri)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Update ref path
|
# Update ref path
|
||||||
new_ref_path = item.ref_path | {ref_uri}
|
new_ref_path = item.ref_path | {ref_uri}
|
||||||
|
|
||||||
# Replace the reference with resolved schema
|
# Replace the reference with resolved schema
|
||||||
next_depth = item.depth + 1
|
next_depth = item.depth + 1
|
||||||
if next_depth >= self.max_depth:
|
if next_depth >= self.max_depth:
|
||||||
raise MaxDepthExceededError(self.max_depth)
|
raise MaxDepthExceededError(self.max_depth)
|
||||||
|
|
||||||
if item.parent is None:
|
if item.parent is None:
|
||||||
# Root level replacement
|
# Root level replacement
|
||||||
item.current.clear()
|
item.current.clear()
|
||||||
item.current.update(resolved_schema)
|
item.current.update(resolved_schema)
|
||||||
queue.append(QueueItem(
|
queue.append(
|
||||||
current=item.current,
|
QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path)
|
||||||
parent=None,
|
)
|
||||||
key=None,
|
|
||||||
depth=next_depth,
|
|
||||||
ref_path=new_ref_path
|
|
||||||
))
|
|
||||||
else:
|
else:
|
||||||
# Update parent container
|
# Update parent container
|
||||||
item.parent[item.key] = resolved_schema.copy()
|
item.parent[item.key] = resolved_schema.copy()
|
||||||
queue.append(QueueItem(
|
queue.append(
|
||||||
current=item.parent[item.key],
|
QueueItem(
|
||||||
parent=item.parent,
|
current=item.parent[item.key],
|
||||||
key=item.key,
|
parent=item.parent,
|
||||||
depth=next_depth,
|
key=item.key,
|
||||||
ref_path=new_ref_path
|
depth=next_depth,
|
||||||
))
|
ref_path=new_ref_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
|
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
|
||||||
"""Get resolved schema from cache or registry"""
|
"""Get resolved schema from cache or registry"""
|
||||||
# Check cache first
|
# Check cache first
|
||||||
with self._cache_lock:
|
with self._cache_lock:
|
||||||
if ref_uri in self._cache:
|
if ref_uri in self._cache:
|
||||||
return self._cache[ref_uri].copy()
|
return self._cache[ref_uri].copy()
|
||||||
|
|
||||||
# Fetch from registry
|
# Fetch from registry
|
||||||
schema = self.registry.get_schema(ref_uri)
|
schema = self.registry.get_schema(ref_uri)
|
||||||
if not schema:
|
if not schema:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Clean and cache
|
# Clean and cache
|
||||||
cleaned = _remove_metadata_fields(schema)
|
cleaned = _remove_metadata_fields(schema)
|
||||||
with self._cache_lock:
|
with self._cache_lock:
|
||||||
self._cache[ref_uri] = cleaned
|
self._cache[ref_uri] = cleaned
|
||||||
|
|
||||||
return cleaned.copy()
|
return cleaned.copy()
|
||||||
|
|
||||||
|
|
||||||
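Reviewer note: `_get_resolved_schema` hands out `dict.copy()` on both the read and write paths, which keeps callers from rebinding top-level keys of a cached entry, but `dict.copy()` is a *shallow* copy, so nested containers are still shared with the cache. A standalone illustration of that caveat (not part of this commit):

```python
# Shallow copies share nested objects: mutating a nested dict through the
# "copy" also mutates the cached original.
cache_entry = {"type": "object", "properties": {"name": {"type": "string"}}}
handed_out = cache_entry.copy()

handed_out["type"] = "array"                      # safe: rebinds a top-level key only
handed_out["properties"]["name"]["type"] = "int"  # leaks: the nested dict is shared

assert cache_entry["type"] == "object"
assert cache_entry["properties"]["name"]["type"] == "int"  # cache was mutated
```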
def resolve_dify_schema_refs(
-    schema: SchemaType,
-    registry: Optional[SchemaRegistry] = None,
-    max_depth: int = 30
    schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
) -> SchemaType:
    """
    Resolve $ref references in Dify schema to actual schema content

    This is a convenience function that creates a resolver and resolves the schema.
    Performance optimization: quickly checks for $ref presence before processing.

    Args:
        schema: Schema object that may contain $ref references
        registry: Optional schema registry, defaults to default registry
        max_depth: Maximum depth to prevent infinite loops (default: 30)

    Returns:
        Schema with all $ref references resolved to actual content

    Raises:
        CircularReferenceError: If circular reference detected
        MaxDepthExceededError: If maximum depth exceeded

@@ -260,7 +248,7 @@ def resolve_dify_schema_refs(
    # This avoids expensive deepcopy and BFS traversal for schemas without refs
    if not _has_dify_refs(schema):
        return schema

    # Slow path: schema contains refs, perform full resolution
    resolver = SchemaResolver(registry, max_depth)
    return resolver.resolve(schema)
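For context, a minimal usage sketch of the convenience function above; the `v1/file.json` URI is an assumed example, not a schema named in this diff:

```python
node_schema = {
    "type": "object",
    "properties": {
        # Assumed example URI; any schema registered in the default registry works here.
        "attachment": {"$ref": "https://dify.ai/schemas/v1/file.json"},
        "note": {"type": "string"},
    },
}

# A schema without Dify refs is returned unchanged (fast path, no deepcopy).
# With the $ref above, the slow path runs and "attachment" comes back expanded inline.
resolved = resolve_dify_schema_refs(node_schema)
```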
@@ -269,36 +257,36 @@ def resolve_dify_schema_refs(

def _remove_metadata_fields(schema: dict) -> dict:
    """
    Remove metadata fields from schema that shouldn't be included in resolved output

    Args:
        schema: Schema dictionary

    Returns:
        Cleaned schema without metadata fields
    """
    # Create a copy and remove metadata fields
    cleaned = schema.copy()
    metadata_fields = ["$id", "$schema", "version"]

    for field in metadata_fields:
        cleaned.pop(field, None)

    return cleaned


def _is_dify_schema_ref(ref_uri: Any) -> bool:
    """
    Check if the reference URI is a Dify schema reference

    Args:
        ref_uri: URI to check

    Returns:
        True if it's a Dify schema reference
    """
    if not isinstance(ref_uri, str):
        return False

    # Use pre-compiled pattern for better performance
    return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri))

@@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool:

def _has_dify_refs_recursive(schema: SchemaType) -> bool:
    """
    Recursively check if a schema contains any Dify $ref references

    This is the fallback method when string-based detection is not possible.

    Args:
        schema: Schema to check for references

    Returns:
        True if any Dify $ref is found, False otherwise
    """

@@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
        ref_uri = schema.get("$ref")
        if ref_uri and _is_dify_schema_ref(ref_uri):
            return True

        # Check nested values
        for value in schema.values():
            if _has_dify_refs_recursive(value):
                return True

    elif isinstance(schema, list):
        # Check each item in the list
        for item in schema:
            if _has_dify_refs_recursive(item):
                return True

    # Primitive types don't contain refs
    return False

@@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:

def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
    """
    Hybrid detection: fast string scan followed by precise recursive check

    Performance optimization using two-phase detection:
    1. Fast string scan to quickly eliminate schemas without $ref
    2. Precise recursive validation only for potential candidates

    Args:
        schema: Schema to check for references

    Returns:
        True if any Dify $ref is found, False otherwise
    """
    # Phase 1: Fast string-based pre-filtering
    try:
        import json

-        schema_str = json.dumps(schema, separators=(',', ':'))
        schema_str = json.dumps(schema, separators=(",", ":"))

        # Quick elimination: no $ref at all
        if '"$ref"' not in schema_str:
            return False

        # Quick elimination: no Dify schema URLs
-        if 'https://dify.ai/schemas/' not in schema_str:
        if "https://dify.ai/schemas/" not in schema_str:
            return False

    except (TypeError, ValueError, OverflowError):
        # JSON serialization failed (e.g., circular references, non-serializable objects)
        # Fall back to recursive detection
        logger.debug("JSON serialization failed for schema, using recursive detection")
        return _has_dify_refs_recursive(schema)

    # Phase 2: Precise recursive validation
    # Only executed for schemas that passed string pre-filtering
    return _has_dify_refs_recursive(schema)
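One step worth spelling out: Phase 1 only proves the text `"$ref"` and a Dify URL occur somewhere in the serialized schema, so Phase 2 must still confirm a real reference. A standalone example of a false positive the string scan cannot rule out (not from this commit):

```python
looks_like_a_ref = {
    "type": "object",
    # A local JSON-Pointer ref plus a Dify URL in prose: passes both string checks...
    "$ref": "#/definitions/local",
    "description": "See https://dify.ai/schemas/ for the hosted schema catalog.",
}

# ...but Phase 2 rejects it: the "$ref" value is not a Dify schema URI,
# and the URL inside the description string is a primitive, not a reference.
assert _has_dify_refs_hybrid(looks_like_a_ref) is False
```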
@@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool:

def _has_dify_refs(schema: SchemaType) -> bool:
    """
    Check if a schema contains any Dify $ref references

    Uses hybrid detection for optimal performance:
    - Fast string scan for quick elimination
    - Precise recursive check for validation

    Args:
        schema: Schema to check for references

    Returns:
        True if any Dify $ref is found, False otherwise
    """

@@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool:

def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
    """
    Parse a Dify schema URI to extract version and schema name

    Args:
        uri: Schema URI to parse

    Returns:
        Tuple of (version, schema_name) or ("", "") if invalid
    """
    match = _DIFY_SCHEMA_PATTERN.match(uri)
    if not match:
        return "", ""

    return match.group(1), match.group(2)
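A quick usage sketch for the parser. The exact capture groups depend on `_DIFY_SCHEMA_PATTERN`, which is defined outside this hunk; the URI shape below is inferred from `SchemaManager.get_schema_by_name`, which builds `https://dify.ai/schemas/{version}/{schema_name}.json`:

```python
version, name = parse_dify_schema_uri("https://dify.ai/schemas/v1/file.json")
# Presumably ("v1", "file"), assuming the pattern captures the version segment
# and the bare schema name; any non-matching string yields ("", "").
assert parse_dify_schema_uri("not-a-schema-uri") == ("", "")
```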
@@ -13,10 +13,10 @@ class SchemaManager:
    def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
        """
        Get all JSON Schema definitions for a specific version

        Args:
            version: Schema version, defaults to v1

        Returns:
            Array containing schema definitions, each element contains name and schema fields
        """

@@ -25,31 +25,28 @@ class SchemaManager:
    def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
        """
        Get a specific schema by name

        Args:
            schema_name: Schema name
            version: Schema version, defaults to v1

        Returns:
            Dictionary containing name and schema, returns None if not found
        """
        uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
        schema = self.registry.get_schema(uri)

        if schema:
-            return {
-                "name": schema_name,
-                "schema": schema
-            }
            return {"name": schema_name, "schema": schema}
        return None

    def list_available_schemas(self, version: str = "v1") -> list[str]:
        """
        List all available schema names for a specific version

        Args:
            version: Schema version, defaults to v1

        Returns:
            List of schema names
        """

@@ -58,8 +55,8 @@ class SchemaManager:
    def list_available_versions(self) -> list[str]:
        """
        List all available schema versions

        Returns:
            List of versions
        """
        return self.registry.list_versions()
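A short sketch of how these lookups compose. The constructor call and the `"file"` schema name are assumptions for illustration; `SchemaManager`'s actual wiring to its registry sits outside this hunk:

```python
# Assumes SchemaManager can be constructed with its default registry wiring.
manager = SchemaManager()

for version in manager.list_available_versions():            # e.g. ["v1"]
    for name in manager.list_available_schemas(version):     # schema names, e.g. ["file"]
        entry = manager.get_schema_by_name(name, version)
        if entry is not None:                                 # None when the URI is not registered
            print(entry["name"], entry["schema"].get("type"))
```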
@@ -152,7 +152,6 @@ class ToolEngine:
        user_id: str,
        workflow_tool_callback: DifyWorkflowCallbackHandler,
        workflow_call_depth: int,
-        thread_pool_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,

@@ -166,7 +165,6 @@ class ToolEngine:

        if isinstance(tool, WorkflowTool):
            tool.workflow_call_depth = workflow_call_depth + 1
-            tool.thread_pool_id = thread_pool_id

        if tool.runtime and tool.runtime.runtime_parameters:
            tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

@@ -13,31 +13,16 @@ from sqlalchemy.orm import Session
from yarl import URL

import contexts
-from core.helper.provider_cache import ToolProviderCredentialsCache
-from core.plugin.entities.plugin import ToolProviderID
-from core.plugin.impl.oauth import OAuthHandler
-from core.plugin.impl.tool import PluginToolManager
-from core.tools.__base.tool_provider import ToolProviderController
-from core.tools.__base.tool_runtime import ToolRuntime
-from core.tools.mcp_tool.provider import MCPToolProviderController
-from core.tools.mcp_tool.tool import MCPTool
-from core.tools.plugin_tool.provider import PluginToolProviderController
-from core.tools.plugin_tool.tool import PluginTool
-from core.tools.utils.uuid_utils import is_valid_uuid
-from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
-from core.workflow.entities.variable_pool import VariablePool
-from services.tools.mcp_tools_manage_service import MCPToolManageService
-
-if TYPE_CHECKING:
-    from core.workflow.nodes.tool.entities import ToolEntity
-
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool

@@ -53,16 +38,28 @@ from core.tools.entities.tool_entities import (
    ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
    ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.provider_ids import ToolProviderID
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService

if TYPE_CHECKING:
    from core.workflow.entities import VariablePool
    from core.workflow.nodes.tool.entities import ToolEntity

logger = logging.getLogger(__name__)

@@ -117,6 +114,8 @@ class ToolManager:
        get the plugin provider
        """
        # check if context is set
        from core.plugin.impl.tool import PluginToolManager

        try:
            contexts.plugin_tool_providers.get()
        except LookupError:

@@ -172,6 +171,7 @@ class ToolManager:

        :return: the tool
        """

        if provider_type == ToolProviderType.BUILT_IN:
            # check if the builtin tool need credentials
            provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

@@ -216,16 +216,16 @@ class ToolManager:
        # fallback to the default provider
        if builtin_provider is None:
            # use the default provider
-            builtin_provider = (
-                db.session.query(BuiltinToolProvider)
-                .where(
-                    BuiltinToolProvider.tenant_id == tenant_id,
-                    (BuiltinToolProvider.provider == str(provider_id_entity))
-                    | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
-                )
-                .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
-                .first()
-            )
            with Session(db.engine) as session:
                builtin_provider = session.scalar(
                    sa.select(BuiltinToolProvider)
                    .where(
                        BuiltinToolProvider.tenant_id == tenant_id,
                        (BuiltinToolProvider.provider == str(provider_id_entity))
                        | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                    )
                    .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                )
            if builtin_provider is None:
                raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
        else:
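The replacement above swaps the legacy `db.session.query(...).first()` for an explicit short-lived session and a SQLAlchemy 2.0-style statement: `session.scalar(select(...))` returns the first result or `None`, so the `.first()` call disappears. A generic sketch of the same pattern outside this diff (the model and fields are hypothetical stand-ins):

```python
import sqlalchemy as sa
from sqlalchemy.orm import Session


def first_provider_for(engine, model, tenant_id: str):
    # Hypothetical model standing in for BuiltinToolProvider.
    with Session(engine) as session:          # session is closed on exit
        return session.scalar(                # first row's scalar value, or None
            sa.select(model).where(model.tenant_id == tenant_id).order_by(model.created_at.asc())
        )
```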
@@ -256,6 +256,7 @@ class ToolManager:
        # check if the credentials is expired
        if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
            # TODO: circular import
            from core.plugin.impl.oauth import OAuthHandler
            from services.tools.builtin_tools_manage_service import BuiltinToolManageService

            # refresh the credentials

@@ -263,6 +264,7 @@ class ToolManager:
            provider_name = tool_provider.provider_name
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
            system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)

            oauth_handler = OAuthHandler()
            # refresh the credentials
            refreshed_credentials = oauth_handler.refresh_credentials(

@@ -358,7 +360,7 @@ class ToolManager:
        app_id: str,
        agent_tool: AgentToolEntity,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
-        variable_pool: Optional[VariablePool] = None,
        variable_pool: Optional["VariablePool"] = None,
    ) -> Tool:
        """
        get the agent tool runtime

@@ -400,7 +402,7 @@ class ToolManager:
        node_id: str,
        workflow_tool: "ToolEntity",
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
-        variable_pool: Optional[VariablePool] = None,
        variable_pool: Optional["VariablePool"] = None,
    ) -> Tool:
        """
        get the workflow tool runtime

@@ -516,6 +518,8 @@ class ToolManager:
        """
        list all the plugin providers
        """
        from core.plugin.impl.tool import PluginToolManager

        manager = PluginToolManager()
        provider_entities = manager.fetch_tool_providers(tenant_id)
        return [

@@ -977,7 +981,7 @@ class ToolManager:
    def _convert_tool_parameters_type(
        cls,
        parameters: list[ToolParameter],
-        variable_pool: Optional[VariablePool],
        variable_pool: Optional["VariablePool"],
        tool_configurations: dict[str, Any],
        typ: Literal["agent", "workflow", "tool"] = "workflow",
    ) -> dict[str, Any]:

@@ -39,14 +39,12 @@ class WorkflowTool(Tool):
        entity: ToolEntity,
        runtime: ToolRuntime,
        label: str = "Workflow",
-        thread_pool_id: Optional[str] = None,
    ):
        self.workflow_app_id = workflow_app_id
        self.workflow_as_tool_id = workflow_as_tool_id
        self.version = version
        self.workflow_entities = workflow_entities
        self.workflow_call_depth = workflow_call_depth
-        self.thread_pool_id = thread_pool_id
        self.label = label

        super().__init__(entity=entity, runtime=runtime)

@@ -90,7 +88,6 @@ class WorkflowTool(Tool):
            invoke_from=self.runtime.invoke_from,
            streaming=False,
            call_depth=self.workflow_call_depth + 1,
-            workflow_thread_pool_id=self.thread_pool_id,
        )
        assert isinstance(result, dict)
        data = result.get("data", {})

@@ -130,7 +130,7 @@ class ArraySegment(Segment):
    def markdown(self) -> str:
        items = []
        for item in self.value:
-            items.append(str(item))
            items.append(f"- {item}")
        return "\n".join(items)
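The one-line change above switches `ArraySegment.markdown` from newline-joined raw values to a Markdown bullet list. For an assumed value of `["a", "b"]`:

```python
value = ["a", "b"]
# Before: "a\nb"        (plain lines, not rendered as a list)
# After:  "- a\n- b"    (renders as a Markdown bullet list)
assert "\n".join(f"- {item}" for item in value) == "- a\n- b"
```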
@@ -1,7 +0,0 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback

__all__ = [
    "WorkflowCallback",
    "WorkflowLoggingCallback",
]

@@ -1,12 +0,0 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.event import GraphEngineEvent


class WorkflowCallback(ABC):
    @abstractmethod
    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Published event
        """
        raise NotImplementedError

@@ -1,263 +0,0 @@
from typing import Optional

from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,
    IterationRunSucceededEvent,
    LoopRunFailedEvent,
    LoopRunNextEvent,
    LoopRunStartedEvent,
    LoopRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
    ParallelBranchRunFailedEvent,
    ParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent,
)

from .base_workflow_callback import WorkflowCallback

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
    "yellow": "33;1",
    "pink": "38;5;200",
    "green": "32;1",
    "red": "31;1",
}


class WorkflowLoggingCallback(WorkflowCallback):
    def __init__(self) -> None:
        self.current_node_id: Optional[str] = None

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, GraphRunStartedEvent):
            self.print_text("\n[GraphRunStartedEvent]", color="pink")
        elif isinstance(event, GraphRunSucceededEvent):
            self.print_text("\n[GraphRunSucceededEvent]", color="green")
        elif isinstance(event, GraphRunPartialSucceededEvent):
            self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
        elif isinstance(event, GraphRunFailedEvent):
            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
        elif isinstance(event, NodeRunStartedEvent):
            self.on_workflow_node_execute_started(event=event)
        elif isinstance(event, NodeRunSucceededEvent):
            self.on_workflow_node_execute_succeeded(event=event)
        elif isinstance(event, NodeRunFailedEvent):
            self.on_workflow_node_execute_failed(event=event)
        elif isinstance(event, NodeRunStreamChunkEvent):
            self.on_node_text_chunk(event=event)
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self.on_workflow_parallel_started(event=event)
        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
            self.on_workflow_parallel_completed(event=event)
        elif isinstance(event, IterationRunStartedEvent):
            self.on_workflow_iteration_started(event=event)
        elif isinstance(event, IterationRunNextEvent):
            self.on_workflow_iteration_next(event=event)
        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
            self.on_workflow_iteration_completed(event=event)
        elif isinstance(event, LoopRunStartedEvent):
            self.on_workflow_loop_started(event=event)
        elif isinstance(event, LoopRunNextEvent):
            self.on_workflow_loop_next(event=event)
        elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
            self.on_workflow_loop_completed(event=event)
        else:
            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")

    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
        """
        Workflow node execute started
        """
        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
        self.print_text(f"Node ID: {event.node_id}", color="yellow")
        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
        self.print_text(f"Type: {event.node_type.value}", color="yellow")

    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
        """
        Workflow node execute succeeded
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunSucceededEvent]", color="green")
        self.print_text(f"Node ID: {event.node_id}", color="green")
        self.print_text(f"Node Title: {event.node_data.title}", color="green")
        self.print_text(f"Type: {event.node_type.value}", color="green")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="green",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="green",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="green",
            )
            self.print_text(
                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
                color="green",
            )

    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
        """
        Workflow node execute failed
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunFailedEvent]", color="red")
        self.print_text(f"Node ID: {event.node_id}", color="red")
        self.print_text(f"Node Title: {event.node_data.title}", color="red")
        self.print_text(f"Type: {event.node_type.value}", color="red")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(f"Error: {node_run_result.error}", color="red")
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="red",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="red",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="red",
            )

    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Publish text chunk
        """
        route_node_state = event.route_node_state
        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
            self.current_node_id = route_node_state.node_id
            self.print_text("\n[NodeRunStreamChunkEvent]")
            self.print_text(f"Node ID: {route_node_state.node_id}")

            node_run_result = route_node_state.node_run_result
            if node_run_result:
                self.print_text(
                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
                )

        self.print_text(event.chunk_content, color="pink", end="")

    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
        """
        Publish parallel started
        """
        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
        if event.in_loop_id:
            self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")

    def on_workflow_parallel_completed(
        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
    ) -> None:
        """
        Publish parallel completed
        """
        if isinstance(event, ParallelBranchRunSucceededEvent):
            color = "blue"
        elif isinstance(event, ParallelBranchRunFailedEvent):
            color = "red"

        self.print_text(
            "\n[ParallelBranchRunSucceededEvent]"
            if isinstance(event, ParallelBranchRunSucceededEvent)
            else "\n[ParallelBranchRunFailedEvent]",
            color=color,
        )
        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
        if event.in_loop_id:
            self.print_text(f"Loop ID: {event.in_loop_id}", color=color)

        if isinstance(event, ParallelBranchRunFailedEvent):
            self.print_text(f"Error: {event.error}", color=color)

    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
        """
        Publish iteration started
        """
        self.print_text("\n[IterationRunStartedEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")

    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
        """
        Publish iteration next
        """
        self.print_text("\n[IterationRunNextEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
        self.print_text(f"Iteration Index: {event.index}", color="blue")

    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
        """
        Publish iteration completed
        """
        self.print_text(
            "\n[IterationRunSucceededEvent]"
            if isinstance(event, IterationRunSucceededEvent)
            else "\n[IterationRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Node ID: {event.iteration_id}", color="blue")

    def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
        """
        Publish loop started
        """
        self.print_text("\n[LoopRunStartedEvent]", color="blue")
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")

    def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
        """
        Publish loop next
        """
        self.print_text("\n[LoopRunNextEvent]", color="blue")
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
        self.print_text(f"Loop Index: {event.index}", color="blue")

    def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
        """
        Publish loop completed
        """
        self.print_text(
            "\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")

    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(f"{text_to_print}", end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
173
api/core/workflow/docs/WORKER_POOL_CONFIG.md
Normal file
@@ -0,0 +1,173 @@
# GraphEngine Worker Pool Configuration

## Overview

The GraphEngine now supports **dynamic worker pool management** to optimize performance and resource usage. Instead of a fixed 10-worker pool, the engine can:

1. **Start with an optimal worker count** based on graph complexity
2. **Scale up** when workload increases
3. **Scale down** when workers are idle
4. **Respect configurable min/max limits**

## Benefits

- **Resource Efficiency**: Uses fewer workers for simple sequential workflows
- **Better Performance**: Scales up for parallel-heavy workflows
- **Gevent Optimization**: Works efficiently with Gevent's greenlet model
- **Memory Savings**: Reduces memory footprint for simple workflows

## Configuration

### Configuration Variables (via dify_config)

| Variable | Default | Description |
|----------|---------|-------------|
| `GRAPH_ENGINE_MIN_WORKERS` | 1 | Minimum number of workers per engine |
| `GRAPH_ENGINE_MAX_WORKERS` | 10 | Maximum number of workers per engine |
| `GRAPH_ENGINE_SCALE_UP_THRESHOLD` | 3 | Queue depth that triggers scale up |
| `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME` | 5.0 | Seconds of idle time before scaling down |

### Example Configurations

#### Low-Resource Environment

```bash
export GRAPH_ENGINE_MIN_WORKERS=1
export GRAPH_ENGINE_MAX_WORKERS=3
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=2
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=3.0
```

#### High-Performance Environment

```bash
export GRAPH_ENGINE_MIN_WORKERS=2
export GRAPH_ENGINE_MAX_WORKERS=20
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=5
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=10.0
```

#### Default (Balanced)

```bash
# Uses defaults: min=1, max=10, threshold=3, idle_time=5.0
```

## How It Works

### Initial Worker Calculation

The engine analyzes the graph structure at startup:

- **Sequential graphs** (no branches): 1 worker
- **Limited parallelism** (few branches): 2 workers
- **Moderate parallelism**: 3 workers
- **High parallelism** (many branches): 5 workers

A sketch of this tiering heuristic follows below.
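The mapping above suggests a simple branch-counting heuristic. The real implementation lives in the GraphEngine worker-management layer and is not shown in this diff; the sketch below is a hypothetical reading of the documented tiers, and the function name, field names, and tier boundaries are assumptions:

```python
def initial_worker_count(out_degrees: list[int], min_workers: int = 1, max_workers: int = 10) -> int:
    """Hypothetical tiering: map the number of branching nodes to the documented worker counts."""
    branching_nodes = sum(1 for d in out_degrees if d > 1)  # nodes that fan out
    if branching_nodes == 0:
        workers = 1  # sequential graph
    elif branching_nodes <= 2:
        workers = 2  # limited parallelism (threshold assumed)
    elif branching_nodes <= 5:
        workers = 3  # moderate parallelism (threshold assumed)
    else:
        workers = 5  # high parallelism
    return max(min_workers, min(workers, max_workers))
```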
### Dynamic Scaling

During execution:

1. **Scale up** triggers when:

   - Queue depth exceeds `SCALE_UP_THRESHOLD`
   - All workers are busy and the queue has items
   - The pool is not at the `MAX_WORKERS` limit

2. **Scale down** triggers when:

   - A worker has been idle for more than `SCALE_DOWN_IDLE_TIME` seconds
   - The pool is above the `MIN_WORKERS` limit

Both triggers are sketched in code below.
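A minimal restatement of the documented triggers as decision functions. This is not the engine's actual code; the signatures are assumptions, and only the parameter defaults mirror the configuration variables above:

```python
import time


def should_scale_up(queue_depth: int, busy: int, total: int,
                    threshold: int = 3, max_workers: int = 10) -> bool:
    """Hypothetical restatement of the documented scale-up triggers."""
    if total >= max_workers:
        return False  # never exceed the configured ceiling
    return queue_depth > threshold or (busy == total and queue_depth > 0)


def should_scale_down(last_active_at: float, total: int,
                      idle_time: float = 5.0, min_workers: int = 1) -> bool:
    """Hypothetical restatement of the documented scale-down triggers."""
    idle_for = time.monotonic() - last_active_at
    return total > min_workers and idle_for > idle_time
```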
|
### Gevent Compatibility
|
||||||
|
|
||||||
|
Since Gevent patches threading to use greenlets:
|
||||||
|
|
||||||
|
- Workers are lightweight coroutines, not OS threads
|
||||||
|
- Dynamic scaling has minimal overhead
|
||||||
|
- Can efficiently handle many concurrent workers
|
||||||
|
|
||||||
|
## Migration Guide
|
||||||
|
|
||||||
|
### Before (Fixed 10 Workers)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Every GraphEngine instance created 10 workers
|
||||||
|
# Resource waste for simple workflows
|
||||||
|
# No adaptation to workload
|
||||||
|
```
|
||||||
|
|
||||||
|
### After (Dynamic Workers)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# GraphEngine creates 1-5 initial workers based on graph
|
||||||
|
# Scales up/down based on workload
|
||||||
|
# Configurable via environment variables
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backward Compatibility
|
||||||
|
|
||||||
|
The default configuration (`max=10`) maintains compatibility with existing deployments. To get the old behavior exactly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export GRAPH_ENGINE_MIN_WORKERS=10
|
||||||
|
export GRAPH_ENGINE_MAX_WORKERS=10
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Impact
|
||||||
|
|
||||||
|
### Memory Usage
|
||||||
|
|
||||||
|
- **Simple workflows**: ~80% reduction (1 vs 10 workers)
|
||||||
|
- **Complex workflows**: Similar or slightly better
|
||||||
|
|
||||||
|
### Execution Time
|
||||||
|
|
||||||
|
- **Sequential workflows**: No change
|
||||||
|
- **Parallel workflows**: Improved with proper scaling
|
||||||
|
- **Bursty workloads**: Better adaptation
|
||||||
|
|
||||||
|
### Example Metrics
|
||||||
|
|
||||||
|
| Workflow Type | Old (10 workers) | New (Dynamic) | Improvement |
|
||||||
|
|--------------|------------------|---------------|-------------|
|
||||||
|
| Sequential | 10 workers idle | 1 worker active | 90% fewer workers |
|
||||||
|
| 3-way parallel | 7 workers idle | 3 workers active | 70% fewer workers |
|
||||||
|
| Heavy parallel | 10 workers busy | 10+ workers (scales up) | Better throughput |
|
||||||
|
|
||||||
|
## Monitoring
|
||||||
|
|
||||||
|
Log messages indicate scaling activity:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
INFO: GraphEngine initialized with 2 workers (min: 1, max: 10)
|
||||||
|
INFO: Scaled up workers: 2 -> 3 (queue_depth: 4)
|
||||||
|
INFO: Scaled down workers: 3 -> 2 (removed 1 idle workers)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Start with defaults** - They work well for most cases
|
||||||
|
1. **Monitor queue depth** - Adjust `SCALE_UP_THRESHOLD` if queues back up
|
||||||
|
1. **Consider workload patterns**:
|
||||||
|
- Bursty: Lower `SCALE_DOWN_IDLE_TIME`
|
||||||
|
- Steady: Higher `SCALE_DOWN_IDLE_TIME`
|
||||||
|
1. **Test with your workloads** - Measure and tune
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Workers not scaling up
|
||||||
|
|
||||||
|
- Check `GRAPH_ENGINE_MAX_WORKERS` limit
|
||||||
|
- Verify queue depth exceeds threshold
|
||||||
|
- Check logs for scaling messages
|
||||||
|
|
||||||
|
### Workers scaling down too quickly
|
||||||
|
|
||||||
|
- Increase `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME`
|
||||||
|
- Consider workload patterns
|
||||||
|
|
||||||
|
### Out of memory
|
||||||
|
|
||||||
|
- Reduce `GRAPH_ENGINE_MAX_WORKERS`
|
||||||
|
- Check for memory leaks in nodes
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
from .agent import AgentNodeStrategyInit
|
||||||
|
from .graph_init_params import GraphInitParams
|
||||||
|
from .graph_runtime_state import GraphRuntimeState
|
||||||
|
from .run_condition import RunCondition
|
||||||
|
from .variable_pool import VariablePool, VariableValue
|
||||||
|
from .workflow_execution import WorkflowExecution
|
||||||
|
from .workflow_node_execution import WorkflowNodeExecution
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentNodeStrategyInit",
|
||||||
|
"GraphInitParams",
|
||||||
|
"GraphRuntimeState",
|
||||||
|
"RunCondition",
|
||||||
|
"VariablePool",
|
||||||
|
"VariableValue",
|
||||||
|
"WorkflowExecution",
|
||||||
|
"WorkflowNodeExecution",
|
||||||
|
]
|
||||||
10
api/core/workflow/entities/agent.py
Normal file
@ -0,0 +1,10 @@
from typing import Optional

from pydantic import BaseModel


class AgentNodeStrategyInit(BaseModel):
    """Agent node strategy initialization data."""

    name: str
    icon: Optional[str] = None
@ -3,19 +3,18 @@ from typing import Any

from pydantic import BaseModel, Field

-from core.app.entities.app_invoke_entities import InvokeFrom
-from models.enums import UserFrom
-from models.workflow import WorkflowType


class GraphInitParams(BaseModel):
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
-    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
-    user_from: UserFrom = Field(..., description="user from, account or end-user")
-    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
+    user_from: str = Field(
+        ..., description="user from, account or end-user"
+    )  # Should be UserFrom enum: 'account' | 'end-user'
+    invoke_from: str = Field(
+        ..., description="invoke from, service-api, web-app, explore or debugger"
+    )  # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
    call_depth: int = Field(..., description="call depth")
@ -3,8 +3,8 @@ from typing import Any

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
+from .variable_pool import VariablePool


class GraphRuntimeState(BaseModel):
@ -26,6 +26,3 @@ class GraphRuntimeState(BaseModel):

    node_run_steps: int = 0
    """node run steps"""

-    node_run_state: RuntimeRouteState = RuntimeRouteState()
-    """node run state"""
@ -1,34 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional

from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class NodeRunResult(BaseModel):
    """
    Node Run Result.
    """

    status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

    inputs: Optional[Mapping[str, Any]] = None  # node inputs
    process_data: Optional[Mapping[str, Any]] = None  # process data
    outputs: Optional[Mapping[str, Any]] = None  # node outputs
    metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None  # node metadata
    llm_usage: Optional[LLMUsage] = None  # llm usage

    edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches

    error: Optional[str] = None  # error message if status is failed
    error_type: Optional[str] = None  # error type if status is failed

    # single step node run retry
    retry_index: int = 0


class AgentNodeStrategyInit(BaseModel):
    name: str
    icon: str | None = None
@ -1,12 +0,0 @@
from collections.abc import Sequence

from pydantic import BaseModel


class VariableSelector(BaseModel):
    """
    Variable Selector.
    """

    variable: str
    value_selector: Sequence[str]
@ -68,10 +68,10 @@ class VariablePool(BaseModel):
        # Add rag pipeline variables to the variable pool
        if self.rag_pipeline_variables:
            rag_pipeline_variables_map = defaultdict(dict)
-            for var in self.rag_pipeline_variables:
-                node_id = var.variable.belong_to_node_id
-                key = var.variable.variable
-                value = var.value
+            for rag_var in self.rag_pipeline_variables:
+                node_id = rag_var.variable.belong_to_node_id
+                key = rag_var.variable.variable
+                value = rag_var.value
                rag_pipeline_variables_map[node_id][key] = value
            for key, value in rag_pipeline_variables_map.items():
                self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
@ -7,32 +7,14 @@ implementation details like tenant_id, app_id, etc.

from collections.abc import Mapping
from datetime import datetime
-from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel, Field

+from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now


-class WorkflowType(StrEnum):
-    """
-    Workflow Type Enum for domain layer
-    """
-
-    WORKFLOW = "workflow"
-    CHAT = "chat"
-    RAG_PIPELINE = "rag-pipeline"
-
-
-class WorkflowExecutionStatus(StrEnum):
-    RUNNING = "running"
-    SUCCEEDED = "succeeded"
-    FAILED = "failed"
-    STOPPED = "stopped"
-    PARTIAL_SUCCEEDED = "partial-succeeded"
-
-
class WorkflowExecution(BaseModel):
    """
    Domain model for workflow execution based on WorkflowRun but without
@ -8,50 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.

from collections.abc import Mapping
from datetime import datetime
-from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel, Field

-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


-class WorkflowNodeExecutionMetadataKey(StrEnum):
-    """
-    Node Run Metadata Key.
-    """
-
-    TOTAL_TOKENS = "total_tokens"
-    TOTAL_PRICE = "total_price"
-    CURRENCY = "currency"
-    TOOL_INFO = "tool_info"
-    AGENT_LOG = "agent_log"
-    ITERATION_ID = "iteration_id"
-    ITERATION_INDEX = "iteration_index"
-    DATASOURCE_INFO = "datasource_info"
-    LOOP_ID = "loop_id"
-    LOOP_INDEX = "loop_index"
-    PARALLEL_ID = "parallel_id"
-    PARALLEL_START_NODE_ID = "parallel_start_node_id"
-    PARENT_PARALLEL_ID = "parent_parallel_id"
-    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
-    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
-    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
-    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
-    ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
-    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
-
-
-class WorkflowNodeExecutionStatus(StrEnum):
-    """
-    Node Execution Status Enum.
-    """
-
-    RUNNING = "running"
-    SUCCEEDED = "succeeded"
-    FAILED = "failed"
-    EXCEPTION = "exception"
-    RETRY = "retry"
-
-
class WorkflowNodeExecution(BaseModel):
@ -1,4 +1,12 @@
-from enum import StrEnum
+from enum import Enum, StrEnum
+
+
+class NodeState(Enum):
+    """State of a node or edge during workflow execution."""
+
+    UNKNOWN = "unknown"
+    TAKEN = "taken"
+    SKIPPED = "skipped"


class SystemVariableKey(StrEnum):
@ -21,3 +29,107 @@ class SystemVariableKey(StrEnum):
    DATASOURCE_TYPE = "datasource_type"
    DATASOURCE_INFO = "datasource_info"
    INVOKE_FROM = "invoke_from"
+
+
+class NodeType(StrEnum):
+    START = "start"
+    END = "end"
+    ANSWER = "answer"
+    LLM = "llm"
+    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+    KNOWLEDGE_INDEX = "knowledge-index"
+    IF_ELSE = "if-else"
+    CODE = "code"
+    TEMPLATE_TRANSFORM = "template-transform"
+    QUESTION_CLASSIFIER = "question-classifier"
+    HTTP_REQUEST = "http-request"
+    TOOL = "tool"
+    DATASOURCE = "datasource"
+    VARIABLE_AGGREGATOR = "variable-aggregator"
+    LEGACY_VARIABLE_AGGREGATOR = "variable-assigner"  # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
+    LOOP = "loop"
+    LOOP_START = "loop-start"
+    LOOP_END = "loop-end"
+    ITERATION = "iteration"
+    ITERATION_START = "iteration-start"  # Fake start node for iteration.
+    PARAMETER_EXTRACTOR = "parameter-extractor"
+    VARIABLE_ASSIGNER = "assigner"
+    DOCUMENT_EXTRACTOR = "document-extractor"
+    LIST_OPERATOR = "list-operator"
+    AGENT = "agent"
+
+
+class NodeExecutionType(StrEnum):
+    """Node execution type classification."""
+
+    EXECUTABLE = "executable"  # Regular nodes that execute and produce outputs
+    RESPONSE = "response"  # Response nodes that stream outputs (Answer, End)
+    BRANCH = "branch"  # Nodes that can choose different branches (if-else, question-classifier)
+    CONTAINER = "container"  # Container nodes that manage subgraphs (iteration, loop, graph)
+
+
+class ErrorStrategy(StrEnum):
+    FAIL_BRANCH = "fail-branch"
+    DEFAULT_VALUE = "default-value"
+
+
+class FailBranchSourceHandle(StrEnum):
+    FAILED = "fail-branch"
+    SUCCESS = "success-branch"
+
+
+class WorkflowType(StrEnum):
+    """
+    Workflow Type Enum for domain layer
+    """
+
+    WORKFLOW = "workflow"
+    CHAT = "chat"
+    RAG_PIPELINE = "rag-pipeline"
+
+
+class WorkflowExecutionStatus(StrEnum):
+    RUNNING = "running"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+    STOPPED = "stopped"
+    PARTIAL_SUCCEEDED = "partial-succeeded"
+
+
+class WorkflowNodeExecutionMetadataKey(StrEnum):
+    """
+    Node Run Metadata Key.
+    """
+
+    TOTAL_TOKENS = "total_tokens"
+    TOTAL_PRICE = "total_price"
+    CURRENCY = "currency"
+    TOOL_INFO = "tool_info"
+    AGENT_LOG = "agent_log"
+    ITERATION_ID = "iteration_id"
+    ITERATION_INDEX = "iteration_index"
+    LOOP_ID = "loop_id"
+    LOOP_INDEX = "loop_index"
+    PARALLEL_ID = "parallel_id"
+    PARALLEL_START_NODE_ID = "parallel_start_node_id"
+    PARENT_PARALLEL_ID = "parent_parallel_id"
+    PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
+    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
+    ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
+    LOOP_DURATION_MAP = "loop_duration_map"  # single loop duration if loop node runs
+    ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
+    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
+    DATASOURCE_INFO = "datasource_info"
+
+
+class WorkflowNodeExecutionStatus(StrEnum):
+    PENDING = "pending"  # Node is scheduled but not yet executing
+    RUNNING = "running"
+    SUCCEEDED = "succeeded"
+    FAILED = "failed"
+    EXCEPTION = "exception"
+    STOPPED = "stopped"
+    PAUSED = "paused"
+
+    # Legacy statuses - kept for backward compatibility
+    RETRY = "retry"  # Legacy: replaced by retry mechanism in error handling
@ -1,8 +1,8 @@
-from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.node import Node


class WorkflowNodeRunFailedError(Exception):
-    def __init__(self, node: BaseNode, err_msg: str):
+    def __init__(self, node: Node, err_msg: str):
        self._node = node
        self._error = err_msg
        super().__init__(f"Node {node.title} run failed: {err_msg}")
5
api/core/workflow/graph/__init__.py
Normal file
@ -0,0 +1,5 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_template import GraphTemplate

__all__ = ["Edge", "Graph", "GraphTemplate", "NodeFactory"]
15
api/core/workflow/graph/edge.py
Normal file
@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field

from core.workflow.enums import NodeState


@dataclass
class Edge:
    """Edge connecting two nodes in a workflow graph."""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    tail: str = ""  # tail node id (source)
    head: str = ""  # head node id (target)
    source_handle: str = "source"  # source handle for conditional branching
    state: NodeState = field(default=NodeState.UNKNOWN)  # edge execution state
266
api/core/workflow/graph/graph.py
Normal file
@ -0,0 +1,266 @@
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, Protocol, cast

from core.workflow.enums import NodeType
from core.workflow.nodes.base.node import Node

from .edge import Edge

logger = logging.getLogger(__name__)


class NodeFactory(Protocol):
    """
    Protocol for creating Node instances from node data dictionaries.

    This protocol decouples the Graph class from specific node mapping implementations,
    allowing for different node creation strategies while maintaining type safety.
    """

    def create_node(self, node_config: dict[str, Any]) -> Node:
        """
        Create a Node instance from node configuration data.

        :param node_config: node configuration dictionary containing type and other data
        :return: initialized Node instance
        :raises ValueError: if node type is unknown or configuration is invalid
        """
        ...


class Graph:
    """Graph representation with nodes and edges for workflow execution."""

    def __init__(
        self,
        *,
        nodes: Optional[dict[str, Node]] = None,
        edges: Optional[dict[str, Edge]] = None,
        in_edges: Optional[dict[str, list[str]]] = None,
        out_edges: Optional[dict[str, list[str]]] = None,
        root_node: Node,
    ):
        """
        Initialize Graph instance.

        :param nodes: graph nodes mapping (node id: node object)
        :param edges: graph edges mapping (edge id: edge object)
        :param in_edges: incoming edges mapping (node id: list of edge ids)
        :param out_edges: outgoing edges mapping (node id: list of edge ids)
        :param root_node: root node object
        """
        self.nodes = nodes or {}
        self.edges = edges or {}
        self.in_edges = in_edges or {}
        self.out_edges = out_edges or {}
        self.root_node = root_node

    @classmethod
    def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
        """
        Parse node configurations and build a mapping of node IDs to configs.

        :param node_configs: list of node configuration dictionaries
        :return: mapping of node ID to node config
        """
        node_configs_map: dict[str, dict[str, Any]] = {}

        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id:
                continue

            node_configs_map[node_id] = node_config

        return node_configs_map

    @classmethod
    def _find_root_node_id(
        cls,
        node_configs_map: dict[str, dict[str, Any]],
        edge_configs: list[dict[str, Any]],
        root_node_id: Optional[str] = None,
    ) -> str:
        """
        Find the root node ID if not specified.

        :param node_configs_map: mapping of node ID to node config
        :param edge_configs: list of edge configurations
        :param root_node_id: explicitly specified root node ID
        :return: determined root node ID
        """
        if root_node_id:
            if root_node_id not in node_configs_map:
                raise ValueError(f"Root node id {root_node_id} not found in the graph")
            return root_node_id

        # Find nodes with no incoming edges
        nodes_with_incoming = set()
        for edge_config in edge_configs:
            target = edge_config.get("target")
            if target:
                nodes_with_incoming.add(target)

        root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]

        # Prefer START node if available
        start_node_id = None
        for nid in root_candidates:
            node_data = node_configs_map[nid].get("data", {})
            if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]:
                start_node_id = nid
                break

        root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)

        if not root_node_id:
            raise ValueError("Unable to determine root node ID")

        return root_node_id

    @classmethod
    def _build_edges(
        cls, edge_configs: list[dict[str, Any]]
    ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
        """
        Build edge objects and mappings from edge configurations.

        :param edge_configs: list of edge configurations
        :return: tuple of (edges dict, in_edges dict, out_edges dict)
        """
        edges: dict[str, Edge] = {}
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)

        edge_counter = 0
        for edge_config in edge_configs:
            source = edge_config.get("source")
            target = edge_config.get("target")

            if not source or not target:
                continue

            # Create edge
            edge_id = f"edge_{edge_counter}"
            edge_counter += 1

            source_handle = edge_config.get("sourceHandle", "source")

            edge = Edge(
                id=edge_id,
                tail=source,
                head=target,
                source_handle=source_handle,
            )

            edges[edge_id] = edge
            out_edges[source].append(edge_id)
            in_edges[target].append(edge_id)

        return edges, dict(in_edges), dict(out_edges)

    @classmethod
    def _create_node_instances(
        cls,
        node_configs_map: dict[str, dict[str, Any]],
        node_factory: "NodeFactory",
    ) -> dict[str, Node]:
        """
        Create node instances from configurations using the node factory.

        :param node_configs_map: mapping of node ID to node config
        :param node_factory: factory for creating node instances
        :return: mapping of node ID to node instance
        """
        nodes: dict[str, Node] = {}

        for node_id, node_config in node_configs_map.items():
            try:
                node_instance = node_factory.create_node(node_config)
            except ValueError as e:
                logger.warning("Failed to create node instance: %s", str(e))
                continue
            nodes[node_id] = node_instance

        return nodes

    @classmethod
    def init(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_factory: "NodeFactory",
        root_node_id: Optional[str] = None,
    ) -> "Graph":
        """
        Initialize graph

        :param graph_config: graph config containing nodes and edges
        :param node_factory: factory for creating node instances from config data
        :param root_node_id: root node id
        :return: graph instance
        """
        # Parse configs
        edge_configs = graph_config.get("edges", [])
        node_configs = graph_config.get("nodes", [])

        if not node_configs:
            raise ValueError("Graph must have at least one node")

        edge_configs = cast(list, edge_configs)
        node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]

        # Parse node configurations
        node_configs_map = cls._parse_node_configs(node_configs)

        # Find root node
        root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)

        # Build edges
        edges, in_edges, out_edges = cls._build_edges(edge_configs)

        # Create node instances
        nodes = cls._create_node_instances(node_configs_map, node_factory)

        # Get root node instance
        root_node = nodes[root_node_id]

        # Create and return the graph
        return cls(
            nodes=nodes,
            edges=edges,
            in_edges=in_edges,
            out_edges=out_edges,
            root_node=root_node,
        )

    @property
    def node_ids(self) -> list[str]:
        """
        Get list of node IDs (compatibility property for existing code)

        :return: list of node IDs
        """
        return list(self.nodes.keys())

    def get_outgoing_edges(self, node_id: str) -> list[Edge]:
        """
        Get all outgoing edges from a node (V2 method)

        :param node_id: node id
        :return: list of outgoing edges
        """
        edge_ids = self.out_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]

    def get_incoming_edges(self, node_id: str) -> list[Edge]:
        """
        Get all incoming edges to a node (V2 method)

        :param node_id: node id
        :return: list of incoming edges
        """
        edge_ids = self.in_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]
20
api/core/workflow/graph/graph_template.py
Normal file
@ -0,0 +1,20 @@
from typing import Any

from pydantic import BaseModel, Field


class GraphTemplate(BaseModel):
    """
    Graph Template for container nodes and subgraph expansion

    According to GraphEngine V2 spec, GraphTemplate contains:
    - nodes: mapping of node definitions
    - edges: mapping of edge definitions
    - root_ids: list of root node IDs
    - output_selectors: list of output selectors for the template
    """

    nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
    edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
    root_ids: list[str] = Field(default_factory=list, description="root node IDs")
    output_selectors: list[str] = Field(default_factory=list, description="output selectors")
187
api/core/workflow/graph_engine/README.md
Normal file
@ -0,0 +1,187 @@
# Graph Engine

Queue-based workflow execution engine for parallel graph processing.

## Architecture

The engine uses a modular architecture with specialized packages:

### Core Components

- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
- **Event Management** (`event_management/`) - Event handling, collection, and emission
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination

### Supporting Components

- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
- **Layers** (`layers/`) - Pluggable middleware for extensions

## Architecture Diagram

```mermaid
classDiagram
    class GraphEngine {
        +run()
        +add_layer()
    }

    class Domain {
        ExecutionContext
        GraphExecution
        NodeExecution
    }

    class EventManagement {
        EventHandlerRegistry
        EventCollector
        EventEmitter
    }

    class StateManagement {
        NodeStateManager
        EdgeStateManager
        ExecutionTracker
    }

    class WorkerManagement {
        WorkerPool
        WorkerFactory
        DynamicScaler
        ActivityTracker
    }

    class GraphTraversal {
        NodeReadinessChecker
        EdgeProcessor
        BranchHandler
        SkipPropagator
    }

    class Orchestration {
        Dispatcher
        ExecutionCoordinator
    }

    class ErrorHandling {
        ErrorHandler
        RetryStrategy
        AbortStrategy
        FailBranchStrategy
    }

    class CommandProcessing {
        CommandProcessor
        AbortCommandHandler
    }

    class CommandChannels {
        InMemoryChannel
        RedisChannel
    }

    class OutputRegistry {
        <<Storage>>
        Scalar Values
        Streaming Data
    }

    class ResponseCoordinator {
        Session Management
        Path Analysis
    }

    class Layers {
        <<Plugin>>
        DebugLoggingLayer
    }

    GraphEngine --> Orchestration : coordinates
    GraphEngine --> Layers : extends

    Orchestration --> EventManagement : processes events
    Orchestration --> WorkerManagement : manages scaling
    Orchestration --> CommandProcessing : checks commands
    Orchestration --> StateManagement : monitors state

    WorkerManagement --> StateManagement : consumes ready queue
    WorkerManagement --> EventManagement : produces events
    WorkerManagement --> Domain : executes nodes

    EventManagement --> ErrorHandling : failed events
    EventManagement --> GraphTraversal : success events
    EventManagement --> ResponseCoordinator : stream events
    EventManagement --> Layers : notifies

    GraphTraversal --> StateManagement : updates states
    GraphTraversal --> Domain : checks graph

    CommandProcessing --> CommandChannels : fetches commands
    CommandProcessing --> Domain : modifies execution

    ErrorHandling --> Domain : handles failures

    StateManagement --> Domain : tracks entities

    ResponseCoordinator --> OutputRegistry : reads outputs

    Domain --> OutputRegistry : writes outputs
```

## Package Relationships

### Core Dependencies

- **Orchestration** acts as the central coordinator, managing all subsystems
- **Domain** provides the core business entities used by all packages
- **EventManagement** serves as the communication backbone between components
- **StateManagement** maintains thread-safe state for the entire system

### Data Flow

1. **Commands** flow from CommandChannels → CommandProcessing → Domain
2. **Events** flow from Workers → EventHandlerRegistry → State updates
3. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
4. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement

### Extension Points

- **Layers** observe all events for monitoring, logging, and custom logic (see the sketch below)
- **ErrorHandling** strategies can be extended for custom failure recovery
- **CommandChannels** can be implemented for different transport mechanisms
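
As a sketch of the layers extension point, the `DebugLoggingLayer` named above can be attached via `add_layer()` before a run. The import path below is an assumption for illustration; only the class name and `add_layer()` appear in this document.

```python
# Assumed import path -- this README only names DebugLoggingLayer under layers/.
from core.workflow.graph_engine.layers import DebugLoggingLayer

# `engine` is a GraphEngine instance constructed as in the Usage section below.
engine.add_layer(DebugLoggingLayer())  # layers are notified of every emitted event
for event in engine.run():
    handle_event(event)
```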
## Execution Flow

1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
2. **Node Discovery**: Traversal components identify ready nodes
3. **Worker Execution**: Workers pull from ready queue and execute nodes
4. **Event Processing**: Dispatcher routes events to appropriate handlers
5. **State Updates**: Managers track node/edge states for next steps
6. **Completion**: Coordinator detects when all nodes are done

## Usage

```python
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel

# Create and run engine
engine = GraphEngine(
    tenant_id="tenant_1",
    app_id="app_1",
    workflow_id="workflow_1",
    graph=graph,
    command_channel=InMemoryChannel(),
)

# Stream execution events
for event in engine.run():
    handle_event(event)
```
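
The `graph` object passed above can be built with `Graph.init` from the graph package shown earlier in this diff. A sketch under assumptions: `workflow_config` and `node_factory` are application-provided placeholders (the factory must implement the `NodeFactory` protocol's `create_node()`).

```python
from core.workflow.graph import Graph

# `workflow_config` and `node_factory` are assumed to be supplied by the
# application; neither is defined in this README.
graph = Graph.init(
    graph_config=workflow_config["graph"],  # mapping with "nodes" and "edges"
    node_factory=node_factory,
)
```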
@ -1,4 +1,3 @@
-from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine

-__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["GraphEngine"]
33
api/core/workflow/graph_engine/command_channels/README.md
Normal file
@ -0,0 +1,33 @@
# Command Channels

Channel implementations for external workflow control.

## Components

### InMemoryChannel

Thread-safe in-memory queue for single-process deployments.

- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue

### RedisChannel

Redis-based queue for distributed deployments.

- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL

## Usage

```python
# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(graph_id="workflow-123"))

# Distributed execution
redis_channel = RedisChannel(
    redis_client=redis_client,
    channel_key="workflow:123:commands"
)
```
@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""

from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel

__all__ = ["InMemoryChannel", "RedisChannel"]
Some files were not shown because too many files have changed in this diff.