mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2
This commit is contained in:
commit
0316eb6064
|
|
@ -460,6 +460,16 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
|||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
output_registry
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handling
|
||||
graph_traversal
|
||||
state_management
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:state-management-layers]
|
||||
name = State Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
execution_tracker
|
||||
node_state_manager
|
||||
edge_state_manager
|
||||
containers =
|
||||
core.workflow.graph_engine.state_management
|
||||
|
||||
[importlinter:contract:worker-management-layers]
|
||||
name = Worker Management Layers
|
||||
type = layers
|
||||
layers =
|
||||
worker_pool
|
||||
worker_factory
|
||||
dynamic_scaler
|
||||
activity_tracker
|
||||
containers =
|
||||
core.workflow.graph_engine.worker_management
|
||||
|
||||
[importlinter:contract:error-handling-strategies]
|
||||
name = Error Handling Strategies
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.error_handling.abort_strategy
|
||||
core.workflow.graph_engine.error_handling.retry_strategy
|
||||
core.workflow.graph_engine.error_handling.fail_branch_strategy
|
||||
core.workflow.graph_engine.error_handling.default_value_strategy
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.graph_traversal.node_readiness
|
||||
core.workflow.graph_engine.graph_traversal.skip_propagator
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
|
|
@ -14,7 +14,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
|
|
@ -35,6 +35,7 @@ from models.dataset import Document as DatasetDocument
|
|||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
|
@ -1570,7 +1571,7 @@ def transform_datasource_credentials():
|
|||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
|
|
@ -1591,4 +1592,4 @@ def install_rag_pipeline_plugins(input_file, output_file, workers):
|
|||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
|
|
|
|||
|
|
@ -529,6 +529,28 @@ class WorkflowConfig(BaseSettings):
|
|||
default=200 * 1024,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,10 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
|||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
|
|
@ -135,9 +138,6 @@ class InstructionGenerateApi(Resource):
|
|||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
|||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
|
|
@ -413,7 +414,12 @@ class WorkflowTaskStopApi(Resource):
|
|||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
|
|
|
|||
|
|
@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
|
|||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from controllers.console.wraps import (
|
|||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
|
|
@ -31,6 +30,7 @@ from fields.document_fields import document_status_fields
|
|||
from libs.login import login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.entities.plugin import DatasourceProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
|
|
|||
|
|
@ -17,10 +17,11 @@ from core.variables.segment_group import SegmentGroup
|
|||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import db
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
|
@ -131,7 +132,7 @@ def _api_prerequisite(f):
|
|||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class DraftRagPipelineApi(Resource):
|
|||
Get draft rag pipeline's workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
|
|
@ -84,7 +84,7 @@ class DraftRagPipelineApi(Resource):
|
|||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
|
@ -161,7 +161,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -198,7 +198,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -235,7 +235,7 @@ class DraftRagPipelineRunApi(Resource):
|
|||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -272,7 +272,7 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
Run published workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -384,8 +384,6 @@ class PublishedRagPipelineRunApi(Resource):
|
|||
#
|
||||
# return result
|
||||
#
|
||||
|
||||
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -396,7 +394,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -441,10 +439,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -487,10 +482,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -519,7 +511,7 @@ class RagPipelineTaskStopApi(Resource):
|
|||
Stop workflow task
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
|
@ -538,7 +530,7 @@ class PublishedRagPipelineApi(Resource):
|
|||
Get published pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
if not pipeline.is_published:
|
||||
return None
|
||||
|
|
@ -558,10 +550,7 @@ class PublishedRagPipelineApi(Resource):
|
|||
Publish workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
|
@ -595,7 +584,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
|
|
@ -613,7 +602,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||
Get default block config
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
|
@ -659,7 +648,7 @@ class PublishedAllRagPipelineApi(Resource):
|
|||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -708,10 +697,7 @@ class RagPipelineByIdApi(Resource):
|
|||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -767,7 +753,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
|
|
@ -792,7 +778,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
|
|
@ -817,7 +803,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
|
|
@ -842,7 +828,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
||||
|
|
@ -926,8 +912,11 @@ class DatasourceListApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
|
||||
if not isinstance(user, Account):
|
||||
raise Forbidden()
|
||||
tenant_id = user.current_tenant_id
|
||||
if not tenant_id:
|
||||
raise Forbidden()
|
||||
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
||||
|
||||
|
|
@ -974,10 +963,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||
"""
|
||||
Set datasource variables
|
||||
"""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
|
|
@ -17,6 +18,9 @@ def get_rag_pipeline(
|
|||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user is not an account")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.model import AppMode, InstalledApp
|
||||
|
|
@ -78,6 +79,11 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -32,4 +32,4 @@ class SpecSchemaDefinitionsApi(Resource):
|
|||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
|
|
|
|||
|
|
@ -21,11 +21,11 @@ from core.mcp.auth.auth_provider import OAuthClientProvider
|
|||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from controllers.common.errors import UnsupportedFileTypeError
|
|||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ from core.errors.error import (
|
|||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
|
|
@ -262,7 +263,12 @@ class WorkflowTaskStopApi(Resource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from controllers.service_api.wraps import (
|
|||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
|
|
|||
|
|
@ -133,6 +133,9 @@ class DocumentAddByTextApi(DatasetApiResource):
|
|||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.errors.error import (
|
|||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
|
@ -75,7 +76,12 @@ class WorkflowTaskStopApi(WebApiResource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,9 @@ class BaseAgentRunner(AppRunner):
|
|||
tenant_id=tenant_id,
|
||||
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
|
||||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
return_resource=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from typing import Any
|
|||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
|
|
@ -23,16 +23,17 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
|
|||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Workflow
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, Message, MessageAnnotation
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
from models.workflow import ConversationVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -76,23 +77,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
@ -144,16 +151,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time())
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
|
|
@ -164,12 +182,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
|||
|
|
@ -30,14 +30,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
|
|
@ -64,8 +59,8 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
|||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
|
|
@ -393,9 +388,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
|
|
@ -440,32 +433,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
|
|
@ -757,8 +724,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
|
|
@ -806,8 +771,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
|
|
@ -820,17 +783,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# For unhandled events, we continue (original behavior)
|
||||
return
|
||||
|
||||
|
|
@ -854,11 +806,6 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
graph_runtime_state = event.graph_runtime_state
|
||||
yield from self._handle_workflow_started_event(event)
|
||||
|
||||
case QueueTextChunkEvent():
|
||||
yield from self._handle_text_chunk_event(
|
||||
event, tts_publisher=tts_publisher, queue_message=queue_message
|
||||
)
|
||||
|
||||
case QueueErrorEvent():
|
||||
yield from self._handle_error_event(event)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
|||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaver,
|
||||
DraftVariableSaverFactory,
|
||||
|
|
|
|||
|
|
@ -126,6 +126,21 @@ class AppQueueManager:
|
|||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag_no_user_check(cls, task_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag without user permission check.
|
||||
This method allows stopping workflows without user context.
|
||||
|
||||
:param task_id: The task ID to stop
|
||||
:return:
|
||||
"""
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
|
|
|
|||
|
|
@ -162,7 +162,9 @@ class ChatAppRunner(AppRunner):
|
|||
config=app_config.dataset,
|
||||
query=query,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source,
|
||||
show_retrieve_source=(
|
||||
app_config.additional_features.show_retrieve_source if app_config.additional_features else False
|
||||
),
|
||||
hit_callback=hit_callback,
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
|
|
|
|||
|
|
@ -17,14 +17,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
|
|
@ -37,20 +32,18 @@ from core.app.entities.task_entities import (
|
|||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import (
|
||||
|
|
@ -180,11 +173,10 @@ class WorkflowResponseConverter:
|
|||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
provider_type=ToolProviderType(event.provider_type),
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
node_data = cast(DatasourceNodeData, event.node_data)
|
||||
|
|
@ -200,11 +192,7 @@ class WorkflowResponseConverter:
|
|||
def workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
|
|
@ -238,9 +226,6 @@ class WorkflowResponseConverter:
|
|||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
|
|
@ -292,50 +277,6 @@ class WorkflowResponseConverter:
|
|||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunStartedEvent,
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_execution_id: str,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_execution_id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -350,13 +291,11 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -374,15 +313,10 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -401,7 +335,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
outputs=json_converter.to_json_encodable(event.outputs),
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
|
|
@ -415,8 +349,6 @@ class WorkflowResponseConverter:
|
|||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -430,7 +362,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
|
|
@ -454,7 +386,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
|
|
@ -462,7 +394,6 @@ class WorkflowResponseConverter:
|
|||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -480,7 +411,7 @@ class WorkflowResponseConverter:
|
|||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
title=event.node_title,
|
||||
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
|
|
@ -11,10 +10,12 @@ from core.app.entities.app_invoke_entities import (
|
|||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
|
@ -22,7 +23,7 @@ from extensions.ext_database import db
|
|||
from models.dataset import Document, Pipeline
|
||||
from models.enums import UserFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -84,24 +85,30 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
@ -121,6 +128,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
datasource_info=self.application_generate_entity.datasource_info,
|
||||
invoke_from=self.application_generate_entity.invoke_from.value,
|
||||
)
|
||||
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
|
|
@ -143,11 +151,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
conversation_variables=[],
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_config=workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
|
|
@ -155,7 +165,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
|
|
@ -166,11 +175,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._update_document_status(
|
||||
|
|
@ -194,10 +202,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
|
||||
def _init_rag_pipeline_graph(
|
||||
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
|
||||
) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
graph_config = workflow.graph_dict
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
|
|
@ -227,7 +238,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
graph_config["nodes"] = real_run_nodes
|
||||
graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -53,7 +53,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -67,7 +66,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -81,7 +79,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
|
|
@ -94,7 +91,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
|
|
@ -186,7 +182,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
|
|
@ -200,7 +195,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
|
|
@ -214,7 +208,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
|
|
@ -237,7 +230,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"context": context,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
"variable_loader": variable_loader,
|
||||
},
|
||||
)
|
||||
|
|
@ -434,17 +426,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
|
|
@ -474,7 +456,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from typing import Optional, cast
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
|
|
@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import (
|
|||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities import GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,7 +32,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
) -> None:
|
||||
|
|
@ -41,7 +41,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_id=application_generate_entity.app_config.app_id,
|
||||
)
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
self._workflow = workflow
|
||||
self._sys_user_id = system_user_id
|
||||
|
||||
|
|
@ -52,24 +51,30 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool.empty(),
|
||||
start_at=time.time(),
|
||||
)
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=self._workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
|
@ -92,15 +97,26 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
conversation_variables=[],
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
# Create Redis command channel for this workflow execution
|
||||
task_id = self.application_generate_entity.task_id
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
command_channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
app_id=self._workflow.app_id,
|
||||
workflow_id=self._workflow.id,
|
||||
workflow_type=WorkflowType.value_of(self._workflow.type),
|
||||
graph=graph,
|
||||
graph_config=self._workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
|
|
@ -111,11 +127,11 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
|
|||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
|
|
@ -25,14 +26,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
|
|
@ -57,8 +53,8 @@ from core.app.entities.task_entities import (
|
|||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
|
@ -349,9 +345,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[
|
||||
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
|
||||
],
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle various node failure events."""
|
||||
|
|
@ -370,32 +364,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
|
||||
def _handle_parallel_branch_started_event(
|
||||
self, event: QueueParallelBranchRunStartedEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch started events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_start_resp
|
||||
|
||||
def _handle_parallel_branch_finished_events(
|
||||
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle parallel branch finished events."""
|
||||
self._ensure_workflow_initialized()
|
||||
|
||||
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution_id=self._workflow_run_id,
|
||||
event=event,
|
||||
)
|
||||
yield parallel_finish_resp
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
|
|
@ -617,8 +585,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
|
||||
# Parallel branch events
|
||||
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
|
||||
# Iteration events
|
||||
QueueIterationStartEvent: self._handle_iteration_start_event,
|
||||
QueueIterationNextEvent: self._handle_iteration_next_event,
|
||||
|
|
@ -633,7 +599,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
|
||||
def _dispatch_event(
|
||||
self,
|
||||
event: Any,
|
||||
event: AppQueueEvent,
|
||||
*,
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
|
|
@ -660,8 +626,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
event,
|
||||
(
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
),
|
||||
):
|
||||
|
|
@ -674,17 +638,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
)
|
||||
return
|
||||
|
||||
# Handle parallel branch finished events with isinstance check
|
||||
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
|
||||
yield from self._handle_parallel_branch_finished_events(
|
||||
event,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tts_publisher=tts_publisher,
|
||||
trace_manager=trace_manager,
|
||||
queue_message=queue_message,
|
||||
)
|
||||
return
|
||||
|
||||
# Handle workflow failed and stop events with isinstance check
|
||||
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
|
||||
yield from self._handle_workflow_failed_and_stop_events(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
|
|
@ -13,14 +14,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
|
|
@ -28,42 +24,39 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeInLoopFailedEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
|
@ -79,7 +72,14 @@ class WorkflowBasedAppRunner:
|
|||
self._variable_loader = variable_loader
|
||||
self._app_id = app_id
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
|
|
@ -91,8 +91,28 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -104,6 +124,7 @@ class WorkflowBasedAppRunner:
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
|
|
@ -145,8 +166,25 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -201,6 +239,7 @@ class WorkflowBasedAppRunner:
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
|
|
@ -242,8 +281,25 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
|
@ -310,29 +366,21 @@ class WorkflowBasedAppRunner:
|
|||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
inputs: Mapping[str, Any] | None = {}
|
||||
process_data: Mapping[str, Any] | None = {}
|
||||
outputs: Mapping[str, Any] | None = {}
|
||||
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeRetryEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
|
|
@ -343,6 +391,8 @@ class WorkflowBasedAppRunner:
|
|||
error=event.error,
|
||||
execution_metadata=execution_metadata,
|
||||
retry_index=event.retry_index,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
|
|
@ -350,44 +400,30 @@ class WorkflowBasedAppRunner:
|
|||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
start_at=event.start_at,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
node_run_result = event.route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
else:
|
||||
inputs = {}
|
||||
process_data = {}
|
||||
outputs = {}
|
||||
execution_metadata = {}
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
process_data = node_run_result.process_data
|
||||
outputs = node_run_result.outputs
|
||||
execution_metadata = node_run_result.metadata
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
start_at=event.start_at,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
|
|
@ -396,34 +432,18 @@ class WorkflowBasedAppRunner:
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
|
@ -434,93 +454,21 @@ class WorkflowBasedAppRunner:
|
|||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
start_at=event.start_at,
|
||||
inputs=event.node_run_result.inputs,
|
||||
process_data=event.node_run_result.process_data,
|
||||
outputs=event.node_run_result.outputs,
|
||||
error=event.node_run_result.error or "Unknown error",
|
||||
execution_metadata=event.node_run_result.metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInLoopFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInLoopFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
text=event.chunk,
|
||||
from_variable_selector=list(event.selector),
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
|
|
@ -533,10 +481,10 @@ class WorkflowBasedAppRunner:
|
|||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
elif isinstance(event, NodeRunAgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
id=event.message_id,
|
||||
label=event.label,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
|
|
@ -547,51 +495,13 @@ class WorkflowBasedAppRunner:
|
|||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
elif isinstance(event, NodeRunIterationStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
|
|
@ -599,55 +509,41 @@ class WorkflowBasedAppRunner:
|
|||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
elif isinstance(event, NodeRunIterationNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunIterationSucceededEvent | NodeRunIterationFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunIterationFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
elif isinstance(event, NodeRunLoopStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
|
|
@ -655,42 +551,32 @@ class WorkflowBasedAppRunner:
|
|||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
elif isinstance(event, NodeRunLoopNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
||||
elif isinstance(event, (NodeRunLoopSucceededEvent | NodeRunLoopFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
||||
error=event.error if isinstance(event, NodeRunLoopFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@ from pydantic import BaseModel
|
|||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
|
@ -43,9 +41,6 @@ class QueueEvent(StrEnum):
|
|||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
|
|
@ -80,15 +75,7 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
|
|
@ -108,20 +95,9 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
node_title: str
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
|
|
@ -134,15 +110,7 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
node_title: str
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
|
|
@ -163,7 +131,7 @@ class QueueLoopStartEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -191,7 +159,7 @@ class QueueLoopNextEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -204,7 +172,6 @@ class QueueLoopNextEvent(AppQueueEvent):
|
|||
"""iteratoin run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
|
|
@ -217,7 +184,7 @@ class QueueLoopCompletedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_title: str
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -364,27 +331,24 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
node_run_index: int = 1 # FIXME(-LAN-): may not used
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
|
|
@ -396,7 +360,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -417,10 +380,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
"""single loop duration map"""
|
||||
loop_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
|
|
@ -454,72 +413,6 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
|||
retry_index: int # retry index
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInLoopFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
|
|
@ -530,7 +423,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
|
@ -563,15 +455,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
|
|
@ -678,61 +562,3 @@ class WorkflowQueueMessage(QueueMessage):
|
|||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
error: str
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
|
@ -71,8 +71,6 @@ class StreamEvent(Enum):
|
|||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
NODE_RETRY = "node_retry"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
|
|
@ -440,54 +438,6 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
|
|
@ -506,8 +456,6 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
|||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
@ -530,12 +478,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
extras: dict = Field(default_factory=dict)
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
@ -567,8 +510,6 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
execution_metadata: Optional[Mapping] = None
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
|
|
@ -622,7 +563,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
|
|||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
|
|||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import (
|
||||
|
|
@ -41,6 +40,7 @@ from models.provider import (
|
|||
ProviderType,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -627,6 +627,7 @@ class ProviderConfiguration(BaseModel):
|
|||
Get custom model credentials.
|
||||
"""
|
||||
# get provider model
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
@ -1124,6 +1125,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
Get provider model setting.
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
@ -1207,6 +1209,7 @@ class ProviderConfiguration(BaseModel):
|
|||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
provider_names = [self.provider.provider]
|
||||
if model_provider_id.is_langgenius():
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ def obfuscated_token(token: str):
|
|||
|
||||
|
||||
def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
from models.engine import db
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
|
|
|
|||
|
|
@ -28,8 +28,9 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
|||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from models import App, Message, WorkflowNodeExecutionModel, db
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from extensions.ext_database import db
|
||||
from models import App, Message, WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
|
|
@ -39,86 +40,89 @@ class TokenBufferMemory:
|
|||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
app_record = self.conversation.app
|
||||
with Session(db.engine) as session:
|
||||
app_record = self.conversation.app
|
||||
|
||||
# fetch limited messages, and return reversed
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
if message_limit and message_limit > 0:
|
||||
message_limit = min(message_limit, 500)
|
||||
else:
|
||||
message_limit = 500
|
||||
|
||||
stmt = stmt.limit(message_limit)
|
||||
|
||||
messages = db.session.scalars(stmt).all()
|
||||
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = db.session.scalar(
|
||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
# fetch limited messages, and return reversed
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
)
|
||||
|
||||
if message_limit and message_limit > 0:
|
||||
message_limit = min(message_limit, 500)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
message_limit = 500
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
stmt = stmt.limit(message_limit)
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
messages = session.scalars(stmt).all()
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
# instead of all messages from the conversation, we only need to extract messages
|
||||
# that belong to the thread of last message
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
messages = list(reversed(thread_messages))
|
||||
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
files = session.query(MessageFile).where(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = None
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_run = session.scalar(
|
||||
select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in file_objs:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# prune the chat message if it exceeds the max token limit
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ from core.model_runtime.errors.invoke import (
|
|||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
|
||||
|
||||
class AIModel(BaseModel):
|
||||
|
|
@ -53,6 +52,8 @@ class AIModel(BaseModel):
|
|||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError
|
||||
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
|
|
@ -140,6 +141,8 @@ class AIModel(BaseModel):
|
|||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
|
||||
# sort credentials
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from core.model_runtime.entities.model_entities import (
|
|||
PriceType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -142,6 +141,8 @@ class LargeLanguageModel(AIModel):
|
|||
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
result = plugin_model_manager.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
@ -340,6 +341,8 @@ class LargeLanguageModel(AIModel):
|
|||
:return:
|
||||
"""
|
||||
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_llm_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from pydantic import ConfigDict
|
|||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class ModerationModel(AIModel):
|
||||
|
|
@ -31,6 +30,8 @@ class ModerationModel(AIModel):
|
|||
self.started_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import Optional
|
|||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class RerankModel(AIModel):
|
||||
|
|
@ -36,6 +35,8 @@ class RerankModel(AIModel):
|
|||
:return: rerank result
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_rerank(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from pydantic import ConfigDict
|
|||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class Speech2TextModel(AIModel):
|
||||
|
|
@ -28,6 +27,8 @@ class Speech2TextModel(AIModel):
|
|||
:return: text for given audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_speech_to_text(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from core.entities.embedding_type import EmbeddingInputType
|
|||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
|
||||
class TextEmbeddingModel(AIModel):
|
||||
|
|
@ -37,6 +36,8 @@ class TextEmbeddingModel(AIModel):
|
|||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
try:
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_text_embedding(
|
||||
|
|
@ -61,6 +62,8 @@ class TextEmbeddingModel(AIModel):
|
|||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_text_embedding_num_tokens(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from pydantic import ConfigDict
|
|||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,6 +41,8 @@ class TTSModel(AIModel):
|
|||
:return: translated audio file
|
||||
"""
|
||||
try:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.invoke_tts(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
@ -65,6 +66,8 @@ class TTSModel(AIModel):
|
|||
:param credentials: The credentials required to access the TTS model.
|
||||
:return: A list of voices supported by the TTS model.
|
||||
"""
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
plugin_model_manager = PluginModelClient()
|
||||
return plugin_model_manager.get_tts_model_voices(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
|
|||
|
|
@ -20,10 +20,8 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
|||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
|
||||
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,6 +35,8 @@ class ModelProviderFactory:
|
|||
provider_position_map: dict[str, int]
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
self.provider_position_map = {}
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
|
|
@ -71,7 +71,7 @@ class ModelProviderFactory:
|
|||
|
||||
return [extension.plugin_model_provider_entity.declaration for extension in sorted_extensions.values()]
|
||||
|
||||
def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]:
|
||||
def get_plugin_model_providers(self) -> Sequence["PluginModelProviderEntity"]:
|
||||
"""
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
|
|
@ -109,7 +109,7 @@ class ModelProviderFactory:
|
|||
plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider)
|
||||
return plugin_model_provider_entity.declaration
|
||||
|
||||
def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity:
|
||||
def get_plugin_model_provider(self, provider: str) -> "PluginModelProviderEntity":
|
||||
"""
|
||||
Get plugin model provider
|
||||
:param provider: provider name
|
||||
|
|
@ -366,6 +366,8 @@ class ModelProviderFactory:
|
|||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
|
||||
# get icon bytes from plugin asset manager
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
|
||||
plugin_asset_manager = PluginAssetManager()
|
||||
return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
|
|
@ -375,5 +377,6 @@ class ModelProviderFactory:
|
|||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
|
|
|
|||
|
|
@ -54,13 +54,10 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
|
|||
)
|
||||
from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,7 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import queue
|
|||
import threading
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from cachetools import LRUCache
|
||||
|
|
@ -30,13 +30,15 @@ from core.ops.entities.trace_entity import (
|
|||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -410,7 +412,7 @@ class TraceTask:
|
|||
self,
|
||||
trace_type: Any,
|
||||
message_id: Optional[str] = None,
|
||||
workflow_execution: Optional[WorkflowExecution] = None,
|
||||
workflow_execution: Optional["WorkflowExecution"] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from core.ops.entities.trace_entity import (
|
|||
)
|
||||
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
|
|
|||
|
|
@ -164,7 +164,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
workflow_thread_pool_id=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.parameter_extractor.entities import (
|
||||
ModelConfig as ParameterExtractorModelConfig,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import enum
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.workflow.nodes.base.entities import NumberType
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
|
|
@ -154,7 +154,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
|||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case PluginParameterType.ANY:
|
||||
if value and not isinstance(value, str | dict | list | NumberType):
|
||||
if value and not isinstance(value, str | dict | list | int | float):
|
||||
raise ValueError("The var selector must be a string, dictionary, list or number.")
|
||||
return value
|
||||
case PluginParameterType.ARRAY:
|
||||
|
|
@ -162,8 +162,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
|||
# Try to parse JSON string for arrays
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, list):
|
||||
return parsed_value
|
||||
|
|
@ -176,8 +174,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
|
|||
# Try to parse JSON string for objects
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
if isinstance(parsed_value, dict):
|
||||
return parsed_value
|
||||
|
|
|
|||
|
|
@ -1,11 +1,9 @@
|
|||
import datetime
|
||||
import enum
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntity
|
||||
|
|
@ -141,60 +139,6 @@ class PluginEntity(PluginInstallation):
|
|||
return self
|
||||
|
||||
|
||||
class GenericProviderID:
|
||||
organization: str
|
||||
plugin_name: str
|
||||
provider_name: str
|
||||
is_hardcoded: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
if not value:
|
||||
raise NotFound("plugin not found, please add plugin")
|
||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||
if re.match(r"^[a-z0-9_-]+$", value):
|
||||
value = f"langgenius/{value}/{value}"
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin id {value}")
|
||||
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
|
||||
def is_langgenius(self) -> bool:
|
||||
return self.organization == "langgenius"
|
||||
|
||||
@property
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}"
|
||||
|
||||
|
||||
class ModelProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius" and self.provider_name == "google":
|
||||
self.plugin_name = "gemini"
|
||||
|
||||
|
||||
class ToolProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
if self.organization == "langgenius":
|
||||
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
|
||||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class DatasourceProviderID(GenericProviderID):
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
super().__init__(value, is_hardcoded)
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from collections.abc import Generator
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.entities.request import PluginInvokeContext
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginAgentClient(BasePluginClient):
|
||||
|
|
|
|||
|
|
@ -10,13 +10,13 @@ from core.datasource.entities.datasource_entities import (
|
|||
OnlineDriveDownloadFileRequest,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDatasourceProviderEntity,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from models.provider_ids import DatasourceProviderID, GenericProviderID
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class DynamicSelectClient(BasePluginClient):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from collections.abc import Sequence
|
|||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
|
|
@ -16,6 +15,7 @@ from core.plugin.entities.plugin_daemon import (
|
|||
PluginListResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from models.provider_ids import GenericProviderID
|
||||
|
||||
|
||||
class PluginInstaller(BasePluginClient):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import Any, Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginToolProviderEntity,
|
||||
|
|
@ -12,6 +11,7 @@ from core.plugin.impl.base import BasePluginClient
|
|||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.schemas.resolver import resolve_dify_schema_refs
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from models.provider_ids import GenericProviderID, ToolProviderID
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ from core.model_runtime.entities.provider_entities import (
|
|||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
|
@ -49,6 +48,7 @@ from models.provider import (
|
|||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,9 @@
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
|
|
@ -16,6 +15,9 @@ from core.rag.splitter.text_splitter import TextSplitter
|
|||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class BaseIndexProcessor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
|
@ -61,7 +63,7 @@ class BaseIndexProcessor(ABC):
|
|||
max_tokens: int,
|
||||
chunk_overlap: int,
|
||||
separator: str,
|
||||
embedding_model_instance: Optional[ModelInstance],
|
||||
embedding_model_instance: Optional["ModelInstance"],
|
||||
) -> TextSplitter:
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
|
|
|
|||
|
|
@ -9,11 +9,8 @@ from typing import Optional, Union
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_execution import (
|
||||
WorkflowExecution,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
|
|
@ -203,5 +200,4 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
|||
session.commit()
|
||||
|
||||
# Update the in-memory cache for faster subsequent lookups
|
||||
logger.debug("Updating cache for execution_id: %s", db_model.id)
|
||||
self._execution_cache[db_model.id] = db_model
|
||||
|
|
|
|||
|
|
@ -12,12 +12,8 @@ from sqlalchemy.engine import Engine
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
|
|
@ -215,7 +211,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
# Update the in-memory cache for faster subsequent lookups
|
||||
# Only cache if we have a node_execution_id to use as the cache key
|
||||
if db_model.node_execution_id:
|
||||
logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id)
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def get_db_models_by_workflow_run(
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
from .resolver import resolve_dify_schema_refs
|
||||
|
||||
__all__ = ["resolve_dify_schema_refs"]
|
||||
__all__ = ["resolve_dify_schema_refs"]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import Any, ClassVar, Optional
|
|||
|
||||
class SchemaRegistry:
|
||||
"""Schema registry manages JSON schemas with version support"""
|
||||
|
||||
|
||||
_default_instance: ClassVar[Optional["SchemaRegistry"]] = None
|
||||
_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
|
|
@ -25,41 +25,41 @@ class SchemaRegistry:
|
|||
if cls._default_instance is None:
|
||||
current_dir = Path(__file__).parent
|
||||
schema_dir = current_dir / "builtin" / "schemas"
|
||||
|
||||
|
||||
registry = cls(str(schema_dir))
|
||||
registry.load_all_versions()
|
||||
|
||||
|
||||
cls._default_instance = registry
|
||||
|
||||
|
||||
return cls._default_instance
|
||||
|
||||
def load_all_versions(self) -> None:
|
||||
"""Scans the schema directory and loads all versions"""
|
||||
if not self.base_dir.exists():
|
||||
return
|
||||
|
||||
|
||||
for entry in self.base_dir.iterdir():
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
version = entry.name
|
||||
if not version.startswith("v"):
|
||||
continue
|
||||
|
||||
|
||||
self._load_version_dir(version, entry)
|
||||
|
||||
def _load_version_dir(self, version: str, version_dir: Path) -> None:
|
||||
"""Loads all schemas in a version directory"""
|
||||
if not version_dir.exists():
|
||||
return
|
||||
|
||||
|
||||
if version not in self.versions:
|
||||
self.versions[version] = {}
|
||||
|
||||
|
||||
for entry in version_dir.iterdir():
|
||||
if entry.suffix != ".json":
|
||||
continue
|
||||
|
||||
|
||||
schema_name = entry.stem
|
||||
self._load_schema(version, schema_name, entry)
|
||||
|
||||
|
|
@ -68,10 +68,10 @@ class SchemaRegistry:
|
|||
try:
|
||||
with open(schema_path, encoding="utf-8") as f:
|
||||
schema = json.load(f)
|
||||
|
||||
|
||||
# Store the schema
|
||||
self.versions[version][schema_name] = schema
|
||||
|
||||
|
||||
# Extract and store metadata
|
||||
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
|
||||
metadata = {
|
||||
|
|
@ -81,26 +81,26 @@ class SchemaRegistry:
|
|||
"deprecated": schema.get("deprecated", False),
|
||||
}
|
||||
self.metadata[uri] = metadata
|
||||
|
||||
|
||||
except (OSError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
|
||||
|
||||
|
||||
def get_schema(self, uri: str) -> Optional[Any]:
|
||||
"""Retrieves a schema by URI with version support"""
|
||||
version, schema_name = self._parse_uri(uri)
|
||||
if not version or not schema_name:
|
||||
return None
|
||||
|
||||
|
||||
version_schemas = self.versions.get(version)
|
||||
if not version_schemas:
|
||||
return None
|
||||
|
||||
|
||||
return version_schemas.get(schema_name)
|
||||
|
||||
def _parse_uri(self, uri: str) -> tuple[str, str]:
|
||||
"""Parses a schema URI to extract version and schema name"""
|
||||
from core.schemas.resolver import parse_dify_schema_uri
|
||||
|
||||
return parse_dify_schema_uri(uri)
|
||||
|
||||
def list_versions(self) -> list[str]:
|
||||
|
|
@ -112,19 +112,15 @@ class SchemaRegistry:
|
|||
version_schemas = self.versions.get(version)
|
||||
if not version_schemas:
|
||||
return []
|
||||
|
||||
|
||||
return sorted(version_schemas.keys())
|
||||
|
||||
def get_all_schemas_for_version(self, version: str = "v1") -> list[Mapping[str, Any]]:
|
||||
"""Returns all schemas for a version in the API format"""
|
||||
version_schemas = self.versions.get(version, {})
|
||||
|
||||
|
||||
result = []
|
||||
for schema_name, schema in version_schemas.items():
|
||||
result.append({
|
||||
"name": schema_name,
|
||||
"label": schema.get("title", schema_name),
|
||||
"schema": schema
|
||||
})
|
||||
|
||||
return result
|
||||
result.append({"name": schema_name, "label": schema.get("title", schema_name), "schema": schema})
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -19,11 +19,13 @@ _DIFY_SCHEMA_PATTERN = re.compile(r"^https://dify\.ai/schemas/(v\d+)/(.+)\.json$
|
|||
|
||||
class SchemaResolutionError(Exception):
|
||||
"""Base exception for schema resolution errors"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CircularReferenceError(SchemaResolutionError):
|
||||
"""Raised when a circular reference is detected"""
|
||||
|
||||
def __init__(self, ref_uri: str, ref_path: list[str]):
|
||||
self.ref_uri = ref_uri
|
||||
self.ref_path = ref_path
|
||||
|
|
@ -32,6 +34,7 @@ class CircularReferenceError(SchemaResolutionError):
|
|||
|
||||
class MaxDepthExceededError(SchemaResolutionError):
|
||||
"""Raised when maximum resolution depth is exceeded"""
|
||||
|
||||
def __init__(self, max_depth: int):
|
||||
self.max_depth = max_depth
|
||||
super().__init__(f"Maximum resolution depth ({max_depth}) exceeded")
|
||||
|
|
@ -39,6 +42,7 @@ class MaxDepthExceededError(SchemaResolutionError):
|
|||
|
||||
class SchemaNotFoundError(SchemaResolutionError):
|
||||
"""Raised when a referenced schema cannot be found"""
|
||||
|
||||
def __init__(self, ref_uri: str):
|
||||
self.ref_uri = ref_uri
|
||||
super().__init__(f"Schema not found: {ref_uri}")
|
||||
|
|
@ -47,6 +51,7 @@ class SchemaNotFoundError(SchemaResolutionError):
|
|||
@dataclass
|
||||
class QueueItem:
|
||||
"""Represents an item in the BFS queue"""
|
||||
|
||||
current: Any
|
||||
parent: Optional[Any]
|
||||
key: Optional[Union[str, int]]
|
||||
|
|
@ -56,39 +61,39 @@ class QueueItem:
|
|||
|
||||
class SchemaResolver:
|
||||
"""Resolver for Dify schema references with caching and optimizations"""
|
||||
|
||||
|
||||
_cache: dict[str, SchemaDict] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
|
||||
"""
|
||||
Initialize the schema resolver
|
||||
|
||||
|
||||
Args:
|
||||
registry: Schema registry to use (defaults to default registry)
|
||||
max_depth: Maximum depth for reference resolution
|
||||
"""
|
||||
self.registry = registry or SchemaRegistry.default_registry()
|
||||
self.max_depth = max_depth
|
||||
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls) -> None:
|
||||
"""Clear the global schema cache"""
|
||||
with cls._cache_lock:
|
||||
cls._cache.clear()
|
||||
|
||||
|
||||
def resolve(self, schema: SchemaType) -> SchemaType:
|
||||
"""
|
||||
Resolve all $ref references in the schema
|
||||
|
||||
|
||||
Performance optimization: quickly checks for $ref presence before processing.
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema to resolve
|
||||
|
||||
|
||||
Returns:
|
||||
Resolved schema with all references expanded
|
||||
|
||||
|
||||
Raises:
|
||||
CircularReferenceError: If circular reference detected
|
||||
MaxDepthExceededError: If max depth exceeded
|
||||
|
|
@ -96,44 +101,39 @@ class SchemaResolver:
|
|||
"""
|
||||
if not isinstance(schema, (dict, list)):
|
||||
return schema
|
||||
|
||||
|
||||
# Fast path: if no Dify refs found, return original schema unchanged
|
||||
# This avoids expensive deepcopy and BFS traversal for schemas without refs
|
||||
if not _has_dify_refs(schema):
|
||||
return schema
|
||||
|
||||
|
||||
# Slow path: schema contains refs, perform full resolution
|
||||
import copy
|
||||
|
||||
result = copy.deepcopy(schema)
|
||||
|
||||
|
||||
# Initialize BFS queue
|
||||
queue = deque([QueueItem(
|
||||
current=result,
|
||||
parent=None,
|
||||
key=None,
|
||||
depth=0,
|
||||
ref_path=set()
|
||||
)])
|
||||
|
||||
queue = deque([QueueItem(current=result, parent=None, key=None, depth=0, ref_path=set())])
|
||||
|
||||
while queue:
|
||||
item = queue.popleft()
|
||||
|
||||
|
||||
# Process the current item
|
||||
self._process_queue_item(queue, item)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _process_queue_item(self, queue: deque, item: QueueItem) -> None:
|
||||
"""Process a single queue item"""
|
||||
if isinstance(item.current, dict):
|
||||
self._process_dict(queue, item)
|
||||
elif isinstance(item.current, list):
|
||||
self._process_list(queue, item)
|
||||
|
||||
|
||||
def _process_dict(self, queue: deque, item: QueueItem) -> None:
|
||||
"""Process a dictionary item"""
|
||||
ref_uri = item.current.get("$ref")
|
||||
|
||||
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
# Handle $ref resolution
|
||||
self._resolve_ref(queue, item, ref_uri)
|
||||
|
|
@ -144,14 +144,10 @@ class SchemaResolver:
|
|||
next_depth = item.depth + 1
|
||||
if next_depth >= self.max_depth:
|
||||
raise MaxDepthExceededError(self.max_depth)
|
||||
queue.append(QueueItem(
|
||||
current=value,
|
||||
parent=item.current,
|
||||
key=key,
|
||||
depth=next_depth,
|
||||
ref_path=item.ref_path
|
||||
))
|
||||
|
||||
queue.append(
|
||||
QueueItem(current=value, parent=item.current, key=key, depth=next_depth, ref_path=item.ref_path)
|
||||
)
|
||||
|
||||
def _process_list(self, queue: deque, item: QueueItem) -> None:
|
||||
"""Process a list item"""
|
||||
for idx, value in enumerate(item.current):
|
||||
|
|
@ -159,14 +155,10 @@ class SchemaResolver:
|
|||
next_depth = item.depth + 1
|
||||
if next_depth >= self.max_depth:
|
||||
raise MaxDepthExceededError(self.max_depth)
|
||||
queue.append(QueueItem(
|
||||
current=value,
|
||||
parent=item.current,
|
||||
key=idx,
|
||||
depth=next_depth,
|
||||
ref_path=item.ref_path
|
||||
))
|
||||
|
||||
queue.append(
|
||||
QueueItem(current=value, parent=item.current, key=idx, depth=next_depth, ref_path=item.ref_path)
|
||||
)
|
||||
|
||||
def _resolve_ref(self, queue: deque, item: QueueItem, ref_uri: str) -> None:
|
||||
"""Resolve a $ref reference"""
|
||||
# Check for circular reference
|
||||
|
|
@ -175,82 +167,78 @@ class SchemaResolver:
|
|||
item.current["$circular_ref"] = True
|
||||
logger.warning("Circular reference detected: %s", ref_uri)
|
||||
return
|
||||
|
||||
|
||||
# Get resolved schema (from cache or registry)
|
||||
resolved_schema = self._get_resolved_schema(ref_uri)
|
||||
if not resolved_schema:
|
||||
logger.warning("Schema not found: %s", ref_uri)
|
||||
return
|
||||
|
||||
|
||||
# Update ref path
|
||||
new_ref_path = item.ref_path | {ref_uri}
|
||||
|
||||
|
||||
# Replace the reference with resolved schema
|
||||
next_depth = item.depth + 1
|
||||
if next_depth >= self.max_depth:
|
||||
raise MaxDepthExceededError(self.max_depth)
|
||||
|
||||
|
||||
if item.parent is None:
|
||||
# Root level replacement
|
||||
item.current.clear()
|
||||
item.current.update(resolved_schema)
|
||||
queue.append(QueueItem(
|
||||
current=item.current,
|
||||
parent=None,
|
||||
key=None,
|
||||
depth=next_depth,
|
||||
ref_path=new_ref_path
|
||||
))
|
||||
queue.append(
|
||||
QueueItem(current=item.current, parent=None, key=None, depth=next_depth, ref_path=new_ref_path)
|
||||
)
|
||||
else:
|
||||
# Update parent container
|
||||
item.parent[item.key] = resolved_schema.copy()
|
||||
queue.append(QueueItem(
|
||||
current=item.parent[item.key],
|
||||
parent=item.parent,
|
||||
key=item.key,
|
||||
depth=next_depth,
|
||||
ref_path=new_ref_path
|
||||
))
|
||||
|
||||
queue.append(
|
||||
QueueItem(
|
||||
current=item.parent[item.key],
|
||||
parent=item.parent,
|
||||
key=item.key,
|
||||
depth=next_depth,
|
||||
ref_path=new_ref_path,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
|
||||
"""Get resolved schema from cache or registry"""
|
||||
# Check cache first
|
||||
with self._cache_lock:
|
||||
if ref_uri in self._cache:
|
||||
return self._cache[ref_uri].copy()
|
||||
|
||||
|
||||
# Fetch from registry
|
||||
schema = self.registry.get_schema(ref_uri)
|
||||
if not schema:
|
||||
return None
|
||||
|
||||
|
||||
# Clean and cache
|
||||
cleaned = _remove_metadata_fields(schema)
|
||||
with self._cache_lock:
|
||||
self._cache[ref_uri] = cleaned
|
||||
|
||||
|
||||
return cleaned.copy()
|
||||
|
||||
|
||||
def resolve_dify_schema_refs(
|
||||
schema: SchemaType,
|
||||
registry: Optional[SchemaRegistry] = None,
|
||||
max_depth: int = 30
|
||||
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
|
||||
) -> SchemaType:
|
||||
"""
|
||||
Resolve $ref references in Dify schema to actual schema content
|
||||
|
||||
|
||||
This is a convenience function that creates a resolver and resolves the schema.
|
||||
Performance optimization: quickly checks for $ref presence before processing.
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema object that may contain $ref references
|
||||
registry: Optional schema registry, defaults to default registry
|
||||
max_depth: Maximum depth to prevent infinite loops (default: 30)
|
||||
|
||||
|
||||
Returns:
|
||||
Schema with all $ref references resolved to actual content
|
||||
|
||||
|
||||
Raises:
|
||||
CircularReferenceError: If circular reference detected
|
||||
MaxDepthExceededError: If maximum depth exceeded
|
||||
|
|
@ -260,7 +248,7 @@ def resolve_dify_schema_refs(
|
|||
# This avoids expensive deepcopy and BFS traversal for schemas without refs
|
||||
if not _has_dify_refs(schema):
|
||||
return schema
|
||||
|
||||
|
||||
# Slow path: schema contains refs, perform full resolution
|
||||
resolver = SchemaResolver(registry, max_depth)
|
||||
return resolver.resolve(schema)
|
||||
|
|
@ -269,36 +257,36 @@ def resolve_dify_schema_refs(
|
|||
def _remove_metadata_fields(schema: dict) -> dict:
|
||||
"""
|
||||
Remove metadata fields from schema that shouldn't be included in resolved output
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Cleaned schema without metadata fields
|
||||
"""
|
||||
# Create a copy and remove metadata fields
|
||||
cleaned = schema.copy()
|
||||
metadata_fields = ["$id", "$schema", "version"]
|
||||
|
||||
|
||||
for field in metadata_fields:
|
||||
cleaned.pop(field, None)
|
||||
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _is_dify_schema_ref(ref_uri: Any) -> bool:
|
||||
"""
|
||||
Check if the reference URI is a Dify schema reference
|
||||
|
||||
|
||||
Args:
|
||||
ref_uri: URI to check
|
||||
|
||||
|
||||
Returns:
|
||||
True if it's a Dify schema reference
|
||||
"""
|
||||
if not isinstance(ref_uri, str):
|
||||
return False
|
||||
|
||||
|
||||
# Use pre-compiled pattern for better performance
|
||||
return bool(_DIFY_SCHEMA_PATTERN.match(ref_uri))
|
||||
|
||||
|
|
@ -306,12 +294,12 @@ def _is_dify_schema_ref(ref_uri: Any) -> bool:
|
|||
def _has_dify_refs_recursive(schema: SchemaType) -> bool:
|
||||
"""
|
||||
Recursively check if a schema contains any Dify $ref references
|
||||
|
||||
|
||||
This is the fallback method when string-based detection is not possible.
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema to check for references
|
||||
|
||||
|
||||
Returns:
|
||||
True if any Dify $ref is found, False otherwise
|
||||
"""
|
||||
|
|
@ -320,18 +308,18 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
|
|||
ref_uri = schema.get("$ref")
|
||||
if ref_uri and _is_dify_schema_ref(ref_uri):
|
||||
return True
|
||||
|
||||
|
||||
# Check nested values
|
||||
for value in schema.values():
|
||||
if _has_dify_refs_recursive(value):
|
||||
return True
|
||||
|
||||
|
||||
elif isinstance(schema, list):
|
||||
# Check each item in the list
|
||||
for item in schema:
|
||||
if _has_dify_refs_recursive(item):
|
||||
return True
|
||||
|
||||
|
||||
# Primitive types don't contain refs
|
||||
return False
|
||||
|
||||
|
|
@ -339,36 +327,37 @@ def _has_dify_refs_recursive(schema: SchemaType) -> bool:
|
|||
def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
|
||||
"""
|
||||
Hybrid detection: fast string scan followed by precise recursive check
|
||||
|
||||
|
||||
Performance optimization using two-phase detection:
|
||||
1. Fast string scan to quickly eliminate schemas without $ref
|
||||
2. Precise recursive validation only for potential candidates
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema to check for references
|
||||
|
||||
|
||||
Returns:
|
||||
True if any Dify $ref is found, False otherwise
|
||||
"""
|
||||
# Phase 1: Fast string-based pre-filtering
|
||||
try:
|
||||
import json
|
||||
schema_str = json.dumps(schema, separators=(',', ':'))
|
||||
|
||||
|
||||
schema_str = json.dumps(schema, separators=(",", ":"))
|
||||
|
||||
# Quick elimination: no $ref at all
|
||||
if '"$ref"' not in schema_str:
|
||||
return False
|
||||
|
||||
|
||||
# Quick elimination: no Dify schema URLs
|
||||
if 'https://dify.ai/schemas/' not in schema_str:
|
||||
if "https://dify.ai/schemas/" not in schema_str:
|
||||
return False
|
||||
|
||||
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
# JSON serialization failed (e.g., circular references, non-serializable objects)
|
||||
# Fall back to recursive detection
|
||||
logger.debug("JSON serialization failed for schema, using recursive detection")
|
||||
return _has_dify_refs_recursive(schema)
|
||||
|
||||
|
||||
# Phase 2: Precise recursive validation
|
||||
# Only executed for schemas that passed string pre-filtering
|
||||
return _has_dify_refs_recursive(schema)
|
||||
|
|
@ -377,14 +366,14 @@ def _has_dify_refs_hybrid(schema: SchemaType) -> bool:
|
|||
def _has_dify_refs(schema: SchemaType) -> bool:
|
||||
"""
|
||||
Check if a schema contains any Dify $ref references
|
||||
|
||||
|
||||
Uses hybrid detection for optimal performance:
|
||||
- Fast string scan for quick elimination
|
||||
- Fast string scan for quick elimination
|
||||
- Precise recursive check for validation
|
||||
|
||||
|
||||
Args:
|
||||
schema: Schema to check for references
|
||||
|
||||
|
||||
Returns:
|
||||
True if any Dify $ref is found, False otherwise
|
||||
"""
|
||||
|
|
@ -394,15 +383,15 @@ def _has_dify_refs(schema: SchemaType) -> bool:
|
|||
def parse_dify_schema_uri(uri: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse a Dify schema URI to extract version and schema name
|
||||
|
||||
|
||||
Args:
|
||||
uri: Schema URI to parse
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (version, schema_name) or ("", "") if invalid
|
||||
"""
|
||||
match = _DIFY_SCHEMA_PATTERN.match(uri)
|
||||
if not match:
|
||||
return "", ""
|
||||
|
||||
return match.group(1), match.group(2)
|
||||
|
||||
return match.group(1), match.group(2)
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ class SchemaManager:
|
|||
def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
|
||||
"""
|
||||
Get all JSON Schema definitions for a specific version
|
||||
|
||||
|
||||
Args:
|
||||
version: Schema version, defaults to v1
|
||||
|
||||
|
||||
Returns:
|
||||
Array containing schema definitions, each element contains name and schema fields
|
||||
"""
|
||||
|
|
@ -25,31 +25,28 @@ class SchemaManager:
|
|||
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
|
||||
"""
|
||||
Get a specific schema by name
|
||||
|
||||
|
||||
Args:
|
||||
schema_name: Schema name
|
||||
version: Schema version, defaults to v1
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary containing name and schema, returns None if not found
|
||||
"""
|
||||
uri = f"https://dify.ai/schemas/{version}/{schema_name}.json"
|
||||
schema = self.registry.get_schema(uri)
|
||||
|
||||
|
||||
if schema:
|
||||
return {
|
||||
"name": schema_name,
|
||||
"schema": schema
|
||||
}
|
||||
return {"name": schema_name, "schema": schema}
|
||||
return None
|
||||
|
||||
def list_available_schemas(self, version: str = "v1") -> list[str]:
|
||||
"""
|
||||
List all available schema names for a specific version
|
||||
|
||||
|
||||
Args:
|
||||
version: Schema version, defaults to v1
|
||||
|
||||
|
||||
Returns:
|
||||
List of schema names
|
||||
"""
|
||||
|
|
@ -58,8 +55,8 @@ class SchemaManager:
|
|||
def list_available_versions(self) -> list[str]:
|
||||
"""
|
||||
List all available schema versions
|
||||
|
||||
|
||||
Returns:
|
||||
List of versions
|
||||
"""
|
||||
return self.registry.list_versions()
|
||||
return self.registry.list_versions()
|
||||
|
|
|
|||
|
|
@ -152,7 +152,6 @@ class ToolEngine:
|
|||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
|
|
@ -166,7 +165,6 @@ class ToolEngine:
|
|||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
|
|
|||
|
|
@ -13,31 +13,16 @@ from sqlalchemy.orm import Session
|
|||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
|
|
@ -53,16 +38,28 @@ from core.tools.entities.tool_entities import (
|
|||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -117,6 +114,8 @@ class ToolManager:
|
|||
get the plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
|
||||
try:
|
||||
contexts.plugin_tool_providers.get()
|
||||
except LookupError:
|
||||
|
|
@ -172,6 +171,7 @@ class ToolManager:
|
|||
|
||||
:return: the tool
|
||||
"""
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
|
|
@ -216,16 +216,16 @@ class ToolManager:
|
|||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
|
|
@ -256,6 +256,7 @@ class ToolManager:
|
|||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
|
|
@ -263,6 +264,7 @@ class ToolManager:
|
|||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
|
|
@ -358,7 +360,7 @@ class ToolManager:
|
|||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
|
|
@ -400,7 +402,7 @@ class ToolManager:
|
|||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: Optional[VariablePool] = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
|
|
@ -516,6 +518,8 @@ class ToolManager:
|
|||
"""
|
||||
list all the plugin providers
|
||||
"""
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_tool_providers(tenant_id)
|
||||
return [
|
||||
|
|
@ -977,7 +981,7 @@ class ToolManager:
|
|||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: Optional[VariablePool],
|
||||
variable_pool: Optional["VariablePool"],
|
||||
tool_configurations: dict[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -39,14 +39,12 @@ class WorkflowTool(Tool):
|
|||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
label: str = "Workflow",
|
||||
thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
self.workflow_app_id = workflow_app_id
|
||||
self.workflow_as_tool_id = workflow_as_tool_id
|
||||
self.version = version
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.label = label
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
|
|
@ -90,7 +88,6 @@ class WorkflowTool(Tool):
|
|||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
workflow_thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class ArraySegment(Segment):
|
|||
def markdown(self) -> str:
|
||||
items = []
|
||||
for item in self.value:
|
||||
items.append(str(item))
|
||||
items.append(f"- {item}")
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
from .base_workflow_callback import WorkflowCallback
|
||||
from .workflow_logging_callback import WorkflowLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"WorkflowCallback",
|
||||
"WorkflowLoggingCallback",
|
||||
]
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Published event
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,263 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base_workflow_callback import WorkflowCallback
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
"red": "31;1",
|
||||
}
|
||||
|
||||
|
||||
class WorkflowLoggingCallback(WorkflowCallback):
|
||||
def __init__(self) -> None:
|
||||
self.current_node_id: Optional[str] = None
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color="green")
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(event=event)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(event=event)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(event=event)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(event=event)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(event=event)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(event=event)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(event=event)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(event=event)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(event=event)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
self.on_workflow_loop_started(event=event)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
self.on_workflow_loop_next(event=event)
|
||||
elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
|
||||
self.on_workflow_loop_completed(event=event)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
|
||||
|
||||
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="yellow")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="yellow")
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color="green")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="green")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="green")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="green")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="green",
|
||||
)
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
|
||||
color="green",
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
self.print_text("\n[NodeRunFailedEvent]", color="red")
|
||||
self.print_text(f"Node ID: {event.node_id}", color="red")
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color="red")
|
||||
self.print_text(f"Type: {event.node_type.value}", color="red")
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Error: {node_run_result.error}", color="red")
|
||||
self.print_text(
|
||||
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Process Data: "
|
||||
f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color="red",
|
||||
)
|
||||
self.print_text(
|
||||
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text("\n[NodeRunStreamChunkEvent]")
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
|
||||
node_run_result = route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
|
||||
)
|
||||
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = "blue"
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = "red"
|
||||
|
||||
self.print_text(
|
||||
"\n[ParallelBranchRunSucceededEvent]"
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent)
|
||||
else "\n[ParallelBranchRunFailedEvent]",
|
||||
color=color,
|
||||
)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[IterationRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[IterationRunNextEvent]", color="blue")
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
|
||||
self.print_text(f"Iteration Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text(
|
||||
"\n[IterationRunSucceededEvent]"
|
||||
if isinstance(event, IterationRunSucceededEvent)
|
||||
else "\n[IterationRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
self.print_text("\n[LoopRunNextEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
|
||||
"""
|
||||
Publish loop completed
|
||||
"""
|
||||
self.print_text(
|
||||
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(f"{text_to_print}", end=end)
|
||||
|
||||
def _get_colored_text(self, text: str, color: str) -> str:
|
||||
"""Get colored text."""
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
# GraphEngine Worker Pool Configuration
|
||||
|
||||
## Overview
|
||||
|
||||
The GraphEngine now supports **dynamic worker pool management** to optimize performance and resource usage. Instead of a fixed 10-worker pool, the engine can:
|
||||
|
||||
1. **Start with optimal worker count** based on graph complexity
|
||||
1. **Scale up** when workload increases
|
||||
1. **Scale down** when workers are idle
|
||||
1. **Respect configurable min/max limits**
|
||||
|
||||
## Benefits
|
||||
|
||||
- **Resource Efficiency**: Uses fewer workers for simple sequential workflows
|
||||
- **Better Performance**: Scales up for parallel-heavy workflows
|
||||
- **Gevent Optimization**: Works efficiently with Gevent's greenlet model
|
||||
- **Memory Savings**: Reduces memory footprint for simple workflows
|
||||
|
||||
## Configuration
|
||||
|
||||
### Configuration Variables (via dify_config)
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `GRAPH_ENGINE_MIN_WORKERS` | 1 | Minimum number of workers per engine |
|
||||
| `GRAPH_ENGINE_MAX_WORKERS` | 10 | Maximum number of workers per engine |
|
||||
| `GRAPH_ENGINE_SCALE_UP_THRESHOLD` | 3 | Queue depth that triggers scale up |
|
||||
| `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME` | 5.0 | Seconds of idle time before scaling down |
|
||||
|
||||
### Example Configurations
|
||||
|
||||
#### Low-Resource Environment
|
||||
|
||||
```bash
|
||||
export GRAPH_ENGINE_MIN_WORKERS=1
|
||||
export GRAPH_ENGINE_MAX_WORKERS=3
|
||||
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=2
|
||||
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=3.0
|
||||
```
|
||||
|
||||
#### High-Performance Environment
|
||||
|
||||
```bash
|
||||
export GRAPH_ENGINE_MIN_WORKERS=2
|
||||
export GRAPH_ENGINE_MAX_WORKERS=20
|
||||
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=5
|
||||
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=10.0
|
||||
```
|
||||
|
||||
#### Default (Balanced)
|
||||
|
||||
```bash
|
||||
# Uses defaults: min=1, max=10, threshold=3, idle_time=5.0
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Initial Worker Calculation
|
||||
|
||||
The engine analyzes the graph structure at startup:
|
||||
|
||||
- **Sequential graphs** (no branches): 1 worker
|
||||
- **Limited parallelism** (few branches): 2 workers
|
||||
- **Moderate parallelism**: 3 workers
|
||||
- **High parallelism** (many branches): 5 workers
|
||||
|
||||
### Dynamic Scaling
|
||||
|
||||
During execution:
|
||||
|
||||
1. **Scale Up** triggers when:
|
||||
|
||||
- Queue depth exceeds `SCALE_UP_THRESHOLD`
|
||||
- All workers are busy and queue has items
|
||||
- Not at `MAX_WORKERS` limit
|
||||
|
||||
1. **Scale Down** triggers when:
|
||||
|
||||
- Worker idle for more than `SCALE_DOWN_IDLE_TIME` seconds
|
||||
- Above `MIN_WORKERS` limit
|
||||
|
||||
### Gevent Compatibility
|
||||
|
||||
Since Gevent patches threading to use greenlets:
|
||||
|
||||
- Workers are lightweight coroutines, not OS threads
|
||||
- Dynamic scaling has minimal overhead
|
||||
- Can efficiently handle many concurrent workers
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Before (Fixed 10 Workers)
|
||||
|
||||
```python
|
||||
# Every GraphEngine instance created 10 workers
|
||||
# Resource waste for simple workflows
|
||||
# No adaptation to workload
|
||||
```
|
||||
|
||||
### After (Dynamic Workers)
|
||||
|
||||
```python
|
||||
# GraphEngine creates 1-5 initial workers based on graph
|
||||
# Scales up/down based on workload
|
||||
# Configurable via environment variables
|
||||
```
|
||||
|
||||
### Backward Compatibility
|
||||
|
||||
The default configuration (`max=10`) maintains compatibility with existing deployments. To get the old behavior exactly:
|
||||
|
||||
```bash
|
||||
export GRAPH_ENGINE_MIN_WORKERS=10
|
||||
export GRAPH_ENGINE_MAX_WORKERS=10
|
||||
```
|
||||
|
||||
## Performance Impact
|
||||
|
||||
### Memory Usage
|
||||
|
||||
- **Simple workflows**: ~80% reduction (1 vs 10 workers)
|
||||
- **Complex workflows**: Similar or slightly better
|
||||
|
||||
### Execution Time
|
||||
|
||||
- **Sequential workflows**: No change
|
||||
- **Parallel workflows**: Improved with proper scaling
|
||||
- **Bursty workloads**: Better adaptation
|
||||
|
||||
### Example Metrics
|
||||
|
||||
| Workflow Type | Old (10 workers) | New (Dynamic) | Improvement |
|
||||
|--------------|------------------|---------------|-------------|
|
||||
| Sequential | 10 workers idle | 1 worker active | 90% fewer workers |
|
||||
| 3-way parallel | 7 workers idle | 3 workers active | 70% fewer workers |
|
||||
| Heavy parallel | 10 workers busy | 10+ workers (scales up) | Better throughput |
|
||||
|
||||
## Monitoring
|
||||
|
||||
Log messages indicate scaling activity:
|
||||
|
||||
```shell
|
||||
INFO: GraphEngine initialized with 2 workers (min: 1, max: 10)
|
||||
INFO: Scaled up workers: 2 -> 3 (queue_depth: 4)
|
||||
INFO: Scaled down workers: 3 -> 2 (removed 1 idle workers)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Start with defaults** - They work well for most cases
|
||||
1. **Monitor queue depth** - Adjust `SCALE_UP_THRESHOLD` if queues back up
|
||||
1. **Consider workload patterns**:
|
||||
- Bursty: Lower `SCALE_DOWN_IDLE_TIME`
|
||||
- Steady: Higher `SCALE_DOWN_IDLE_TIME`
|
||||
1. **Test with your workloads** - Measure and tune
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Workers not scaling up
|
||||
|
||||
- Check `GRAPH_ENGINE_MAX_WORKERS` limit
|
||||
- Verify queue depth exceeds threshold
|
||||
- Check logs for scaling messages
|
||||
|
||||
### Workers scaling down too quickly
|
||||
|
||||
- Increase `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME`
|
||||
- Consider workload patterns
|
||||
|
||||
### Out of memory
|
||||
|
||||
- Reduce `GRAPH_ENGINE_MAX_WORKERS`
|
||||
- Check for memory leaks in nodes
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .graph_runtime_state import GraphRuntimeState
|
||||
from .run_condition import RunCondition
|
||||
from .variable_pool import VariablePool, VariableValue
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"RunCondition",
|
||||
"VariablePool",
|
||||
"VariableValue",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: Optional[str] = None
|
||||
|
|
@ -3,19 +3,18 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
user_from: str = Field(
|
||||
..., description="user from, account or end-user"
|
||||
) # Should be UserFrom enum: 'account' | 'end-user'
|
||||
invoke_from: str = Field(
|
||||
..., description="invoke from, service-api, web-app, explore or debugger"
|
||||
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
|
|
@ -3,8 +3,8 @@ from typing import Any
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
|
|
@ -26,6 +26,3 @@ class GraphRuntimeState(BaseModel):
|
|||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
"""
|
||||
|
||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
error: Optional[str] = None # error message if status is failed
|
||||
error_type: Optional[str] = None # error type if status is failed
|
||||
|
||||
# single step node run retry
|
||||
retry_index: int = 0
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
name: str
|
||||
icon: str | None = None
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
|
@ -68,10 +68,10 @@ class VariablePool(BaseModel):
|
|||
# Add rag pipeline variables to the variable pool
|
||||
if self.rag_pipeline_variables:
|
||||
rag_pipeline_variables_map = defaultdict(dict)
|
||||
for var in self.rag_pipeline_variables:
|
||||
node_id = var.variable.belong_to_node_id
|
||||
key = var.variable.variable
|
||||
value = var.value
|
||||
for rag_var in self.rag_pipeline_variables:
|
||||
node_id = rag_var.variable.belong_to_node_id
|
||||
key = rag_var.variable.variable
|
||||
value = rag_var.value
|
||||
rag_pipeline_variables_map[node_id][key] = value
|
||||
for key, value in rag_pipeline_variables_map.items():
|
||||
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
|
||||
|
|
|
|||
|
|
@ -7,32 +7,14 @@ implementation details like tenant_id, app_id, etc.
|
|||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum for domain layer
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""
|
||||
Domain model for workflow execution based on WorkflowRun but without
|
||||
|
|
|
|||
|
|
@ -8,50 +8,11 @@ and don't contain implementation details like tenant_id, app_id, etc.
|
|||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
"""
|
||||
Node Execution Status Enum.
|
||||
"""
|
||||
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
RETRY = "retry"
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,12 @@
|
|||
from enum import StrEnum
|
||||
from enum import Enum, StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
TAKEN = "taken"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class SystemVariableKey(StrEnum):
|
||||
|
|
@ -21,3 +29,107 @@ class SystemVariableKey(StrEnum):
|
|||
DATASOURCE_TYPE = "datasource_type"
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
INVOKE_FROM = "invoke_from"
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
KNOWLEDGE_INDEX = "knowledge-index"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TEMPLATE_TRANSFORM = "template-transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
DATASOURCE = "datasource"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
LOOP_END = "loop-end"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
"""Node execution type classification."""
|
||||
|
||||
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
|
||||
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
|
||||
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
|
||||
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
FAIL_BRANCH = "fail-branch"
|
||||
DEFAULT_VALUE = "default-value"
|
||||
|
||||
|
||||
class FailBranchSourceHandle(StrEnum):
|
||||
FAILED = "fail-branch"
|
||||
SUCCESS = "success-branch"
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
Workflow Type Enum for domain layer
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
CHAT = "chat"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class WorkflowExecutionStatus(StrEnum):
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
STOPPED = "stopped"
|
||||
PARTIAL_SUCCEEDED = "partial-succeeded"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
PENDING = "pending" # Node is scheduled but not yet executing
|
||||
RUNNING = "running"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
EXCEPTION = "exception"
|
||||
STOPPED = "stopped"
|
||||
PAUSED = "paused"
|
||||
|
||||
# Legacy statuses - kept for backward compatibility
|
||||
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node: BaseNode, err_msg: str):
|
||||
def __init__(self, node: Node, err_msg: str):
|
||||
self._node = node
|
||||
self._error = err_msg
|
||||
super().__init__(f"Node {node.title} run failed: {err_msg}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from .edge import Edge
|
||||
from .graph import Graph, NodeFactory
|
||||
from .graph_template import GraphTemplate
|
||||
|
||||
__all__ = ["Edge", "Graph", "GraphTemplate", "NodeFactory"]
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from core.workflow.enums import NodeState
|
||||
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
"""Edge connecting two nodes in a workflow graph."""
|
||||
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
tail: str = "" # tail node id (source)
|
||||
head: str = "" # head node id (target)
|
||||
source_handle: str = "source" # source handle for conditional branching
|
||||
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Protocol, cast
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .edge import Edge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeFactory(Protocol):
|
||||
"""
|
||||
Protocol for creating Node instances from node data dictionaries.
|
||||
|
||||
This protocol decouples the Graph class from specific node mapping implementations,
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
:param node_config: node configuration dictionary containing type and other data
|
||||
:return: initialized Node instance
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Graph:
|
||||
"""Graph representation with nodes and edges for workflow execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: Optional[dict[str, Node]] = None,
|
||||
edges: Optional[dict[str, Edge]] = None,
|
||||
in_edges: Optional[dict[str, list[str]]] = None,
|
||||
out_edges: Optional[dict[str, list[str]]] = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
Initialize Graph instance.
|
||||
|
||||
:param nodes: graph nodes mapping (node id: node object)
|
||||
:param edges: graph edges mapping (edge id: edge object)
|
||||
:param in_edges: incoming edges mapping (node id: list of edge ids)
|
||||
:param out_edges: outgoing edges mapping (node id: list of edge ids)
|
||||
:param root_node: root node object
|
||||
"""
|
||||
self.nodes = nodes or {}
|
||||
self.edges = edges or {}
|
||||
self.in_edges = in_edges or {}
|
||||
self.out_edges = out_edges or {}
|
||||
self.root_node = root_node
|
||||
|
||||
@classmethod
|
||||
def _parse_node_configs(cls, node_configs: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Parse node configurations and build a mapping of node IDs to configs.
|
||||
|
||||
:param node_configs: list of node configuration dictionaries
|
||||
:return: mapping of node ID to node config
|
||||
"""
|
||||
node_configs_map: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
node_configs_map[node_id] = node_config
|
||||
|
||||
return node_configs_map
|
||||
|
||||
@classmethod
|
||||
def _find_root_node_id(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
edge_configs: list[dict[str, Any]],
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param edge_configs: list of edge configurations
|
||||
:param root_node_id: explicitly specified root node ID
|
||||
:return: determined root node ID
|
||||
"""
|
||||
if root_node_id:
|
||||
if root_node_id not in node_configs_map:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
return root_node_id
|
||||
|
||||
# Find nodes with no incoming edges
|
||||
nodes_with_incoming = set()
|
||||
for edge_config in edge_configs:
|
||||
target = edge_config.get("target")
|
||||
if target:
|
||||
nodes_with_incoming.add(target)
|
||||
|
||||
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
|
||||
|
||||
# Prefer START node if available
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid].get("data", {})
|
||||
if node_data.get("type") in [NodeType.START, NodeType.DATASOURCE]:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
|
||||
|
||||
if not root_node_id:
|
||||
raise ValueError("Unable to determine root node ID")
|
||||
|
||||
return root_node_id
|
||||
|
||||
@classmethod
|
||||
def _build_edges(
|
||||
cls, edge_configs: list[dict[str, Any]]
|
||||
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
|
||||
"""
|
||||
Build edge objects and mappings from edge configurations.
|
||||
|
||||
:param edge_configs: list of edge configurations
|
||||
:return: tuple of (edges dict, in_edges dict, out_edges dict)
|
||||
"""
|
||||
edges: dict[str, Edge] = {}
|
||||
in_edges: dict[str, list[str]] = defaultdict(list)
|
||||
out_edges: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
edge_counter = 0
|
||||
for edge_config in edge_configs:
|
||||
source = edge_config.get("source")
|
||||
target = edge_config.get("target")
|
||||
|
||||
if not source or not target:
|
||||
continue
|
||||
|
||||
# Create edge
|
||||
edge_id = f"edge_{edge_counter}"
|
||||
edge_counter += 1
|
||||
|
||||
source_handle = edge_config.get("sourceHandle", "source")
|
||||
|
||||
edge = Edge(
|
||||
id=edge_id,
|
||||
tail=source,
|
||||
head=target,
|
||||
source_handle=source_handle,
|
||||
)
|
||||
|
||||
edges[edge_id] = edge
|
||||
out_edges[source].append(edge_id)
|
||||
in_edges[target].append(edge_id)
|
||||
|
||||
return edges, dict(in_edges), dict(out_edges)
|
||||
|
||||
@classmethod
|
||||
def _create_node_instances(
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
node_factory: "NodeFactory",
|
||||
) -> dict[str, Node]:
|
||||
"""
|
||||
Create node instances from configurations using the node factory.
|
||||
|
||||
:param node_configs_map: mapping of node ID to node config
|
||||
:param node_factory: factory for creating node instances
|
||||
:return: mapping of node ID to node instance
|
||||
"""
|
||||
nodes: dict[str, Node] = {}
|
||||
|
||||
for node_id, node_config in node_configs_map.items():
|
||||
try:
|
||||
node_instance = node_factory.create_node(node_config)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to create node instance: %s", str(e))
|
||||
continue
|
||||
nodes[node_id] = node_instance
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: Optional[str] = None,
|
||||
) -> "Graph":
|
||||
"""
|
||||
Initialize graph
|
||||
|
||||
:param graph_config: graph config containing nodes and edges
|
||||
:param node_factory: factory for creating node instances from config data
|
||||
:param root_node_id: root node id
|
||||
:return: graph instance
|
||||
"""
|
||||
# Parse configs
|
||||
edge_configs = graph_config.get("edges", [])
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
# Find root node
|
||||
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
|
||||
|
||||
# Build edges
|
||||
edges, in_edges, out_edges = cls._build_edges(edge_configs)
|
||||
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
out_edges=out_edges,
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get list of node IDs (compatibility property for existing code)
|
||||
|
||||
:return: list of node IDs
|
||||
"""
|
||||
return list(self.nodes.keys())
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all outgoing edges from a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of outgoing edges
|
||||
"""
|
||||
edge_ids = self.out_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
||||
def get_incoming_edges(self, node_id: str) -> list[Edge]:
|
||||
"""
|
||||
Get all incoming edges to a node (V2 method)
|
||||
|
||||
:param node_id: node id
|
||||
:return: list of incoming edges
|
||||
"""
|
||||
edge_ids = self.in_edges.get(node_id, [])
|
||||
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphTemplate(BaseModel):
|
||||
"""
|
||||
Graph Template for container nodes and subgraph expansion
|
||||
|
||||
According to GraphEngine V2 spec, GraphTemplate contains:
|
||||
- nodes: mapping of node definitions
|
||||
- edges: mapping of edge definitions
|
||||
- root_ids: list of root node IDs
|
||||
- output_selectors: list of output selectors for the template
|
||||
"""
|
||||
|
||||
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
|
||||
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
|
||||
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
|
||||
output_selectors: list[str] = Field(default_factory=list, description="output selectors")
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
# Graph Engine
|
||||
|
||||
Queue-based workflow execution engine for parallel graph processing.
|
||||
|
||||
## Architecture
|
||||
|
||||
The engine uses a modular architecture with specialized packages:
|
||||
|
||||
### Core Components
|
||||
|
||||
- **Domain** (`domain/`) - Core models: ExecutionContext, GraphExecution, NodeExecution
|
||||
- **Event Management** (`event_management/`) - Event handling, collection, and emission
|
||||
- **State Management** (`state_management/`) - Thread-safe state tracking for nodes and edges
|
||||
- **Error Handling** (`error_handling/`) - Strategy-based error recovery (retry, abort, fail-branch, default-value)
|
||||
- **Graph Traversal** (`graph_traversal/`) - Node readiness, edge processing, branch handling
|
||||
- **Command Processing** (`command_processing/`) - External command handling (abort, pause, resume)
|
||||
- **Worker Management** (`worker_management/`) - Dynamic worker pool with auto-scaling
|
||||
- **Orchestration** (`orchestration/`) - Main event loop and execution coordination
|
||||
|
||||
### Supporting Components
|
||||
|
||||
- **Output Registry** (`output_registry/`) - Thread-safe storage for node outputs
|
||||
- **Response Coordinator** (`response_coordinator/`) - Ordered streaming of response nodes
|
||||
- **Command Channels** (`command_channels/`) - Command transport (InMemory/Redis)
|
||||
- **Layers** (`layers/`) - Pluggable middleware for extensions
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
classDiagram
|
||||
class GraphEngine {
|
||||
+run()
|
||||
+add_layer()
|
||||
}
|
||||
|
||||
class Domain {
|
||||
ExecutionContext
|
||||
GraphExecution
|
||||
NodeExecution
|
||||
}
|
||||
|
||||
class EventManagement {
|
||||
EventHandlerRegistry
|
||||
EventCollector
|
||||
EventEmitter
|
||||
}
|
||||
|
||||
class StateManagement {
|
||||
NodeStateManager
|
||||
EdgeStateManager
|
||||
ExecutionTracker
|
||||
}
|
||||
|
||||
class WorkerManagement {
|
||||
WorkerPool
|
||||
WorkerFactory
|
||||
DynamicScaler
|
||||
ActivityTracker
|
||||
}
|
||||
|
||||
class GraphTraversal {
|
||||
NodeReadinessChecker
|
||||
EdgeProcessor
|
||||
BranchHandler
|
||||
SkipPropagator
|
||||
}
|
||||
|
||||
class Orchestration {
|
||||
Dispatcher
|
||||
ExecutionCoordinator
|
||||
}
|
||||
|
||||
class ErrorHandling {
|
||||
ErrorHandler
|
||||
RetryStrategy
|
||||
AbortStrategy
|
||||
FailBranchStrategy
|
||||
}
|
||||
|
||||
class CommandProcessing {
|
||||
CommandProcessor
|
||||
AbortCommandHandler
|
||||
}
|
||||
|
||||
class CommandChannels {
|
||||
InMemoryChannel
|
||||
RedisChannel
|
||||
}
|
||||
|
||||
class OutputRegistry {
|
||||
<<Storage>>
|
||||
Scalar Values
|
||||
Streaming Data
|
||||
}
|
||||
|
||||
class ResponseCoordinator {
|
||||
Session Management
|
||||
Path Analysis
|
||||
}
|
||||
|
||||
class Layers {
|
||||
<<Plugin>>
|
||||
DebugLoggingLayer
|
||||
}
|
||||
|
||||
GraphEngine --> Orchestration : coordinates
|
||||
GraphEngine --> Layers : extends
|
||||
|
||||
Orchestration --> EventManagement : processes events
|
||||
Orchestration --> WorkerManagement : manages scaling
|
||||
Orchestration --> CommandProcessing : checks commands
|
||||
Orchestration --> StateManagement : monitors state
|
||||
|
||||
WorkerManagement --> StateManagement : consumes ready queue
|
||||
WorkerManagement --> EventManagement : produces events
|
||||
WorkerManagement --> Domain : executes nodes
|
||||
|
||||
EventManagement --> ErrorHandling : failed events
|
||||
EventManagement --> GraphTraversal : success events
|
||||
EventManagement --> ResponseCoordinator : stream events
|
||||
EventManagement --> Layers : notifies
|
||||
|
||||
GraphTraversal --> StateManagement : updates states
|
||||
GraphTraversal --> Domain : checks graph
|
||||
|
||||
CommandProcessing --> CommandChannels : fetches commands
|
||||
CommandProcessing --> Domain : modifies execution
|
||||
|
||||
ErrorHandling --> Domain : handles failures
|
||||
|
||||
StateManagement --> Domain : tracks entities
|
||||
|
||||
ResponseCoordinator --> OutputRegistry : reads outputs
|
||||
|
||||
Domain --> OutputRegistry : writes outputs
|
||||
```
|
||||
|
||||
## Package Relationships
|
||||
|
||||
### Core Dependencies
|
||||
|
||||
- **Orchestration** acts as the central coordinator, managing all subsystems
|
||||
- **Domain** provides the core business entities used by all packages
|
||||
- **EventManagement** serves as the communication backbone between components
|
||||
- **StateManagement** maintains thread-safe state for the entire system
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Commands** flow from CommandChannels → CommandProcessing → Domain
|
||||
1. **Events** flow from Workers → EventHandlerRegistry → State updates
|
||||
1. **Node outputs** flow from Workers → OutputRegistry → ResponseCoordinator
|
||||
1. **Ready nodes** flow from GraphTraversal → StateManagement → WorkerManagement
|
||||
|
||||
### Extension Points
|
||||
|
||||
- **Layers** observe all events for monitoring, logging, and custom logic
|
||||
- **ErrorHandling** strategies can be extended for custom failure recovery
|
||||
- **CommandChannels** can be implemented for different transport mechanisms
|
||||
|
||||
## Execution Flow
|
||||
|
||||
1. **Initialization**: GraphEngine creates all subsystems with the workflow graph
|
||||
1. **Node Discovery**: Traversal components identify ready nodes
|
||||
1. **Worker Execution**: Workers pull from ready queue and execute nodes
|
||||
1. **Event Processing**: Dispatcher routes events to appropriate handlers
|
||||
1. **State Updates**: Managers track node/edge states for next steps
|
||||
1. **Completion**: Coordinator detects when all nodes are done
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
|
||||
# Create and run engine
|
||||
engine = GraphEngine(
|
||||
tenant_id="tenant_1",
|
||||
app_id="app_1",
|
||||
workflow_id="workflow_1",
|
||||
graph=graph,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Stream execution events
|
||||
for event in engine.run():
|
||||
handle_event(event)
|
||||
```
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
|
||||
from .graph_engine import GraphEngine
|
||||
|
||||
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
|
||||
__all__ = ["GraphEngine"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
# Command Channels
|
||||
|
||||
Channel implementations for external workflow control.
|
||||
|
||||
## Components
|
||||
|
||||
### InMemoryChannel
|
||||
|
||||
Thread-safe in-memory queue for single-process deployments.
|
||||
|
||||
- `fetch_commands()` - Get pending commands
|
||||
- `send_command()` - Add command to queue
|
||||
|
||||
### RedisChannel
|
||||
|
||||
Redis-based queue for distributed deployments.
|
||||
|
||||
- `fetch_commands()` - Get commands with JSON deserialization
|
||||
- `send_command()` - Store commands with TTL
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Local execution
|
||||
channel = InMemoryChannel()
|
||||
channel.send_command(AbortCommand(graph_id="workflow-123"))
|
||||
|
||||
# Distributed execution
|
||||
redis_channel = RedisChannel(
|
||||
redis_client=redis_client,
|
||||
channel_key="workflow:123:commands"
|
||||
)
|
||||
```
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
"""Command channel implementations for GraphEngine."""
|
||||
|
||||
from .in_memory_channel import InMemoryChannel
|
||||
from .redis_channel import RedisChannel
|
||||
|
||||
__all__ = ["InMemoryChannel", "RedisChannel"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue