Merge remote-tracking branch 'myori/main' into p363

This commit is contained in:
hjlarry 2026-03-02 13:49:21 +08:00
commit 92ab5dae98
474 changed files with 1818 additions and 1882 deletions

View File

@ -29,20 +29,26 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Prepare diagnostics extractor
run: |
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
- name: Run pyrefly on PR branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
- name: Compute diff
run: |
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Save PR number
run: |

View File

@ -68,10 +68,9 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
@echo "📝 Running type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
test:
@ -132,7 +131,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make type-check - Run type checks (basedpyright, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@ -52,7 +54,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
# TODO(QuantumGhost): use DI to avoid depending on global DB.
@ -107,14 +108,11 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@ -131,12 +129,10 @@ ignore_imports =
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@ -150,7 +146,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
@ -172,7 +167,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@ -180,7 +174,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services

View File

@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from core.workflow.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@ -1,7 +1,8 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[

View File

@ -2,12 +2,12 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from core.workflow.file import FileUploadConfig
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
from models.model import AppMode
@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
class RagPipelineVariableEntity(WorkflowVariableEntity):
"""
Rag Pipeline Variable Entity.
"""
@ -314,7 +260,7 @@ class AppConfig(BaseModel):
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
variables: list[WorkflowVariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@ -1,6 +1,7 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity
from core.workflow.variables.input_entities import VariableEntity
from models.workflow import Workflow

View File

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.workflow.variables.input_entities import VariableEntityType
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
from core.workflow.variables.input_entities import VariableEntity
class BaseAppGenerator:

View File

@ -1 +1,5 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

93
api/core/app/llm/quota.py Normal file
View File

@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
    """Fail fast when the model behind ``model_instance`` is flagged as out of quota.

    The check only applies when the provider configuration reports that the
    SYSTEM provider type is in use; any other provider type returns
    immediately without looking up the model.

    Args:
        model_instance: The model instance whose provider configuration,
            model type and model name are used for the lookup.

    Raises:
        QuotaExceededError: If the resolved provider model exists and its
            status is ``ModelStatus.QUOTA_EXCEEDED``.
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    resolved_model = configuration.get_provider_model(
        model_type=model_instance.model_type_instance.model_type,
        model=model_instance.model_name,
    )
    if not resolved_model:
        return
    if resolved_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """Charge the usage of one LLM invocation against the tenant's system quota.

    Nothing is deducted when the provider configuration is not using the
    SYSTEM provider type, when the quota configuration matching the current
    quota type has ``quota_limit == -1``, or when no amount can be derived
    (no matching quota unit, or no current quota type).

    The amount charged depends on the quota unit of the matching quota
    configuration: ``usage.total_tokens`` for TOKENS, the per-model credit
    value from ``dify_config`` for CREDITS, and ``1`` for any other unit.

    TRIAL and PAID quota types are delegated to ``CreditPoolService``; every
    other quota type is recorded by incrementing ``Provider.quota_used``
    directly in the database.

    Args:
        tenant_id: Tenant whose quota is charged.
        model_instance: Model instance that produced ``usage``.
        usage: Usage record of the invocation.
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = configuration.system_configuration
    active_quota_type = system_configuration.current_quota_type

    # Only the first quota configuration matching the active quota type is considered.
    matching_configuration = next(
        (
            candidate
            for candidate in system_configuration.quota_configurations
            if candidate.quota_type == active_quota_type
        ),
        None,
    )

    quota_unit = None
    if matching_configuration is not None:
        quota_unit = matching_configuration.quota_unit
        if matching_configuration.quota_limit == -1:
            return

    # Translate the quota unit into the amount to charge for this invocation.
    if not quota_unit:
        used_quota = None
    elif quota_unit == QuotaUnit.TOKENS:
        used_quota = usage.total_tokens
    elif quota_unit == QuotaUnit.CREDITS:
        used_quota = dify_config.get_model_credits(model_instance.model_name)
    else:
        used_quota = 1

    if used_quota is None or active_quota_type is None:
        return

    charged_to_trial_pool = active_quota_type == ProviderQuotaType.TRIAL
    if charged_to_trial_pool or active_quota_type == ProviderQuotaType.PAID:
        # Imported lazily, exactly where it is needed.
        from services.credit_pool_service import CreditPoolService

        if charged_to_trial_pool:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
        else:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        return

    with Session(db.engine) as session:
        row_filters = (
            Provider.tenant_id == tenant_id,
            # TODO: Use provider name with prefix after the data migration.
            Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
            Provider.provider_type == ProviderType.SYSTEM.value,
            Provider.quota_type == active_quota_type.value,
            # Rows that already reached their limit are left untouched.
            Provider.quota_limit > Provider.quota_used,
        )
        increment_usage = (
            update(Provider)
            .where(*row_filters)
            .values(
                quota_used=Provider.quota_used + used_quota,
                last_used=naive_utc_now(),
            )
        )
        session.execute(increment_usage)
        session.commit()

View File

@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
    """GraphEngine layer that checks LLM quota before, and deducts it after, node runs.

    Only LLM, parameter-extractor and question-classifier nodes are inspected;
    every other node type is ignored. When a quota error is raised the layer
    sets the runtime state's stop event (if one exists) and sends a single
    abort command through the command channel.
    """

    def __init__(self) -> None:
        super().__init__()
        # Guards against sending more than one abort command per graph run.
        self._abort_sent = False

    @override
    def on_graph_start(self) -> None:
        """Reset the abort guard at the start of a graph run."""
        self._abort_sent = False

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        """Ignore engine events; this layer only reacts to node run hooks."""
        del event  # intentionally unused

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        """Nothing to clean up when the graph finishes."""
        del error  # intentionally unused

    @override
    def on_node_run_start(self, node: Node) -> None:
        """Abort the run up front when the node's model is already out of quota."""
        if self._abort_sent:
            return

        if (llm := self._extract_model_instance(node)) is None:
            return

        try:
            ensure_llm_quota_available(model_instance=llm)
        except QuotaExceededError as quota_error:
            self._set_stop_event(node)
            self._send_abort_command(reason=str(quota_error))
            logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, quota_error)

    @override
    def on_node_run_end(
        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        """Deduct quota for a node that finished with a ``NodeRunSucceededEvent``.

        Failed runs and runs without a success event are skipped. A quota
        error aborts the graph; any other failure during deduction is logged
        and swallowed.
        """
        if error is not None:
            return
        if not isinstance(result_event, NodeRunSucceededEvent):
            return

        if (llm := self._extract_model_instance(node)) is None:
            return

        try:
            deduct_llm_quota(
                tenant_id=node.tenant_id,
                model_instance=llm,
                usage=result_event.node_run_result.llm_usage,
            )
        except QuotaExceededError as quota_error:
            self._set_stop_event(node)
            self._send_abort_command(reason=str(quota_error))
            logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, quota_error)
        except Exception:
            logger.exception("LLM quota deduction failed, node_id=%s", node.id)

    @staticmethod
    def _set_stop_event(node: Node) -> None:
        """Set the runtime state's ``stop_event`` when the state exposes one."""
        runtime_state = node.graph_runtime_state
        signal = getattr(runtime_state, "stop_event", None)
        if signal is None:
            return
        signal.set()

    def _send_abort_command(self, *, reason: str) -> None:
        """Send one ``AbortCommand`` with ``reason``; later calls are no-ops.

        Does nothing without a command channel. A failure while sending is
        logged and leaves the abort guard unset.
        """
        if self.command_channel and not self._abort_sent:
            try:
                send = self.command_channel.send_command
                send(
                    AbortCommand(
                        command_type=CommandType.ABORT,
                        reason=reason,
                    )
                )
                self._abort_sent = True
            except Exception:
                logger.exception("Failed to send quota abort command")

    @staticmethod
    def _extract_model_instance(node: Node) -> ModelInstance | None:
        """Return the node's ``model_instance`` for LLM-backed node types, else ``None``.

        An ``AttributeError`` raised while reading the node is logged as a
        warning and also yields ``None``.
        """
        try:
            kind = node.node_type
            if kind == NodeType.LLM:
                return cast("LLMNode", node).model_instance
            if kind == NodeType.PARAMETER_EXTRACTOR:
                return cast("ParameterExtractorNode", node).model_instance
            if kind == NodeType.QUESTION_CLASSIFIER:
                return cast("QuestionClassifierNode", node).model_instance
            return None
        except AttributeError:
            logger.warning(
                "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
                node.id,
            )
            return None

View File

@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@ -11,14 +13,16 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@ -29,11 +33,9 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
    *,
    conversation_id: str | None,
    app_id: str,
    node_data_memory: MemoryConfig | None,
    model_instance: ModelInstance,
) -> TokenBufferMemory | None:
    """Build a ``TokenBufferMemory`` for the conversation, if there is one.

    Args:
        conversation_id: Id of the conversation to load; falsy means no memory.
        app_id: App the conversation must belong to.
        node_data_memory: Memory configuration of the node; falsy means no memory.
        model_instance: Model instance handed to the resulting memory object.

    Returns:
        A ``TokenBufferMemory`` bound to the matching conversation, or ``None``
        when memory is not configured, no conversation id is given, or no
        conversation matches both ``app_id`` and ``conversation_id``.
    """
    if not (node_data_memory and conversation_id):
        return None

    with Session(db.engine, expire_on_commit=False) as session:
        lookup = select(Conversation).where(
            Conversation.app_id == app_id,
            Conversation.id == conversation_id,
        )
        found = session.scalar(lookup)

    if not found:
        return None
    return TokenBufferMemory(conversation=found, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,

View File

@ -4,10 +4,10 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

View File

@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
    """Structural interface for objects that can supply history as prompt messages."""

    def get_history_prompt_messages(
        self,
        max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """Load historical prompt messages within the given limits.

        Args:
            max_token_limit: Token limit applied to the returned history;
                defaults to ``DEFAULT_MEMORY_MAX_TOKEN_LIMIT``.
            message_limit: Optional message limit applied to the returned
                history; ``None`` is the default.

        Returns:
            The historical prompt messages, constrained by the token/message limits.
        """
        ...

View File

@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None, # ty: ignore [invalid-argument-type]
content=None,
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None, # ty: ignore [invalid-argument-type]
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,

View File

@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@ -6,9 +6,9 @@ identity:
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scrapper tool kit is used to scrape web
en_US: Web Scraper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
pt_BR: Web Scraper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity

View File

@ -1,11 +1,11 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.variables.input_entities import VariableEntity
class WorkflowToolConfigurationUtils:

View File

@ -5,7 +5,6 @@ from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -77,13 +76,10 @@ class GraphEngine:
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
@ -163,7 +159,6 @@ class GraphEngine:
layers=self._layers,
execution_context=execution_context,
config=self._config,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -194,7 +189,6 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -314,7 +308,6 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@ -348,7 +341,6 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it

View File

@ -44,7 +44,6 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -62,7 +61,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = stop_event
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
@ -70,12 +69,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

View File

@ -42,7 +42,6 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
) -> None:
@ -63,16 +62,13 @@ class Worker(threading.Thread):
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = stop_event
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:

View File

@ -37,7 +37,6 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None,
) -> None:
@ -64,7 +63,6 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +133,6 @@ class WorkerPool:
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()

View File

@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@ -1,26 +1,19 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from .exc import InvalidVariableTypeError
@ -48,88 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
def convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
used_quota = 1
continue
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def fetch_memory_text(
    *,
    memory: PromptMessageMemory,
    max_token_limit: int,
    message_limit: int | None = None,
    human_prefix: str = "Human",
    ai_prefix: str = "Assistant",
) -> str:
    """Load history from ``memory`` and render it as role-prefixed text.

    Args:
        memory: Source of history; only ``get_history_prompt_messages`` is used.
        max_token_limit: Forwarded to ``memory.get_history_prompt_messages``.
        message_limit: Forwarded to ``memory.get_history_prompt_messages``.
        human_prefix: Label passed through for user messages.
        ai_prefix: Label passed through for assistant messages.

    Returns:
        The string produced by ``convert_history_messages_to_text`` for the
        fetched messages.
    """
    # Fetch and format in one expression; the formatter receives exactly the
    # messages the memory returned for the given limits.
    return convert_history_messages_to_text(
        history_messages=memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        ),
        human_prefix=human_prefix,
        ai_prefix=ai_prefix,
    )

View File

@ -37,6 +37,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -62,7 +63,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
@ -278,8 +279,6 @@ class LLMNode(Node[LLMNodeData]):
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@ -1234,6 +1233,10 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
@ -1336,48 +1339,16 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,

View File

@ -1,10 +1,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@ -21,13 +19,3 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

View File

@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@ -18,13 +17,18 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
@ -108,6 +113,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
@ -118,6 +124,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -163,13 +170,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@ -308,9 +309,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@ -319,7 +317,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@ -407,7 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -445,7 +443,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -470,7 +468,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -483,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -715,7 +714,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -724,8 +723,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -742,7 +741,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -751,8 +750,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -828,6 +827,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -3,9 +3,9 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -56,6 +56,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@ -67,6 +68,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@ -81,6 +83,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -103,13 +106,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@ -240,6 +237,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -323,7 +324,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@ -336,7 +337,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
input_text = query
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
memory_str = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)

View File

@ -2,8 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.input_entities import VariableEntity
class StartNodeData(BaseNodeData):

View File

@ -2,12 +2,12 @@ from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.variables.input_entities import VariableEntityType
class StartNode(Node[StartNodeData]):

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@ -219,8 +218,6 @@ class GraphRuntimeState:
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@ -1,3 +1,4 @@
from .input_entities import VariableEntity, VariableEntityType
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
@ -64,4 +65,6 @@ __all__ = [
"StringVariable",
"Variable",
"VariableBase",
"VariableEntity",
"VariableEntityType",
]

View File

@ -0,0 +1,62 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Any
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.workflow.file import FileTransferMethod, FileType
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
    """
    Shared variable entity used by workflow runtime and app configuration.

    ``description`` and ``options`` accept ``None`` (or any falsy value) on
    input and normalize it to an empty default; ``json_schema``, when given,
    must be a valid Draft 7 JSON schema.
    """

    # `variable` records the name of the variable in user inputs.
    variable: str
    label: str
    description: str = ""
    type: VariableEntityType
    required: bool = False
    hide: bool = False
    default: Any = None
    max_length: int | None = None
    options: Sequence[str] = Field(default_factory=list)
    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
    allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
    json_schema: dict[str, Any] | None = Field(default=None)

    @field_validator("description", mode="before")
    @classmethod
    def convert_none_description(cls, value: Any) -> str:
        """Replace a falsy raw description (e.g. ``None``) with ``""`` before type validation."""
        return value or ""

    @field_validator("options", mode="before")
    @classmethod
    def convert_none_options(cls, value: Any) -> Sequence[str]:
        """Replace falsy raw options (e.g. ``None``) with an empty list before type validation."""
        return value or []

    @field_validator("json_schema")
    @classmethod
    def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
        """Check that ``schema`` is itself a valid Draft 7 JSON schema.

        Returns:
            The schema unchanged, or ``None`` when no schema was supplied.

        Raises:
            ValueError: If ``Draft7Validator.check_schema`` rejects the schema.
        """
        if schema is None:
            return None
        try:
            Draft7Validator.check_schema(schema)
        except SchemaError as error:
            # Chain the original SchemaError so the root cause stays attached
            # to the ValueError instead of being reported as an incidental
            # "during handling of the above exception" context.
            raise ValueError(f"Invalid JSON schema: {error.message}") from error
        return schema

View File

@ -6,6 +6,7 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@ -106,6 +107,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@ -0,0 +1,48 @@
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
from __future__ import annotations
import sys
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
_LOCATION_PREFIX = "-->"


def extract_diagnostics(raw_output: str) -> str:
    """Reduce raw pyrefly output to its stable diagnostic lines.

    Code excerpts and caret markers in the full output make diffs noisy, so
    only two kinds of lines survive:

    - headline lines beginning with ``ERROR `` or ``WARNING ``
    - the line immediately after a headline, if (ignoring leading whitespace)
      it begins with ``-->``

    Kept lines have trailing whitespace removed. The result is either the
    empty string or the kept lines, each terminated by a newline.
    """
    all_lines = raw_output.splitlines()
    kept: list[str] = []

    # Walk each line alongside its successor; the final line has none.
    for current, following in zip(all_lines, [*all_lines[1:], None]):
        if not current.startswith(_DIAGNOSTIC_PREFIXES):
            continue
        kept.append(current.rstrip())
        if following is not None and following.lstrip().startswith(_LOCATION_PREFIX):
            kept.append(following.rstrip())

    # Terminating every entry yields "" for no entries and a trailing newline otherwise.
    return "".join(f"{entry}\n" for entry in kept)
def main() -> int:
    """Filter pyrefly output arriving on stdin and emit the diagnostics on stdout.

    Always returns ``0``.
    """
    normalized = extract_diagnostics(sys.stdin.read())
    sys.stdout.write(normalized)
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@ -787,7 +787,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (

View File

@ -68,7 +68,7 @@ dependencies = [
"pydantic~=2.12.5",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.12.0",
"pyjwt~=2.10.1",
"pyjwt~=2.11.0",
"pypdfium2==5.2.0",
"python-docx~=1.2.0",
"python-dotenv==1.0.1",
@ -116,7 +116,6 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=38.2.0",
"lxml-stubs~=0.5.1",
"ty>=0.0.14",
"basedpyright~=1.31.0",
"ruff~=0.14.0",
"pytest~=8.3.2",
@ -125,7 +124,7 @@ dev = [
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"testcontainers~=4.13.2",
"types-aiofiles~=24.1.0",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=5.5.0",
"types-colorama~=0.4.15",

View File

@ -29,7 +29,7 @@ from typing import Any, cast
import sqlalchemy as sa
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy import and_, delete, func, null, or_, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
@ -423,9 +423,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if last_seen:
stmt = stmt.where(
or_(
WorkflowRun.created_at > last_seen[0],
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
tuple_(WorkflowRun.created_at, WorkflowRun.id)
> tuple_(
sa.literal(last_seen[0], type_=sa.DateTime()),
sa.literal(last_seen[1], type_=WorkflowRun.id.type),
)
)

View File

@ -8,7 +8,6 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
)
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
@ -20,6 +19,7 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file.models import FileUploadConfig
from core.workflow.nodes import NodeType
from core.workflow.variables.input_entities import VariableEntity
from events.app_event import app_was_created
from extensions.ext_database import db
from models import Account

View File

@ -9,7 +9,6 @@ from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -40,6 +39,7 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.variables import VariableBase
from core.workflow.variables.input_entities import VariableEntityType
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan

View File

@ -1,3 +1,4 @@
import json
import logging
import time
@ -125,7 +126,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
def get_mocked_fetch_memory(memory_text: str):
class MemoryMock:
def get_history_prompt_text(
def get_history_prompt_messages(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
):
return memory_text
return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
return MagicMock(return_value=MemoryMock())
def init_parameter_extractor_node(config: dict):
def init_parameter_extractor_node(config: dict, memory=None):
graph_config = {
"edges": [
{
@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
memory=memory,
)
return node
@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor_with_memory(setup_model_mock):
"""
Test chat parameter extractor with memory.
"""
@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"memory": {"window": {"enabled": True, "size": 50}},
},
},
memory=get_mocked_fetch_memory("customized memory")(),
)
node._model_instance = get_mocked_fetch_model_instance(
@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()

View File

@ -10,11 +10,10 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models import Account, Tenant
from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig

View File

@ -147,8 +147,7 @@ class TestDisableSegmentsFromIndexTask:
document.cleaning_completed_at = fake.date_time_this_year()
document.splitting_completed_at = fake.date_time_this_year()
document.tokens = fake.random_int(min=50, max=500)
document.indexing_started_at = fake.date_time_this_year()
document.indexing_completed_at = fake.date_time_this_year()
document.completed_at = fake.date_time_this_year()
document.indexing_status = "completed"
document.enabled = True
document.archived = False

View File

@ -12,8 +12,6 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -21,12 +19,6 @@ from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
@pytest.fixture(autouse=True)
def _register_dict_adapter_for_psycopg2():
"""Align test DB adapter behavior with dict payloads used in task update flow."""
register_adapter(dict, Json)
class DocumentIndexingSyncTaskTestDataFactory:
"""Create real DB entities for document indexing sync integration tests."""

View File

@ -1,7 +1,7 @@
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def test_validate_inputs_with_zero():

View File

@ -4,7 +4,6 @@ from unittest.mock import Mock, patch
import jsonschema
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.server.streamable_http import (
@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import (
prepare_tool_arguments,
process_mapping_response,
)
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@ -0,0 +1,174 @@
import threading
from datetime import datetime
from unittest.mock import MagicMock, patch
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.commands import CommandType
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
def _build_succeeded_event() -> NodeRunSucceededEvent:
    """Build a succeeded LLM-node run event whose result carries an empty usage record."""
    run_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs={"question": "hello"},
        llm_usage=LLMUsage.empty_usage(),
    )
    return NodeRunSucceededEvent(
        id="execution-id",
        node_id="llm-node-id",
        node_type=NodeType.LLM,
        start_at=datetime.now(),
        node_run_result=run_result,
    )
def test_deduct_quota_called_for_successful_llm_node() -> None:
    """A succeeded LLM node leads to one quota deduction with its tenant, model and usage."""
    quota_layer = LLMQuotaLayer()
    llm_node = MagicMock(
        id="llm-node-id",
        execution_id="execution-id",
        node_type=NodeType.LLM,
        tenant_id="tenant-id",
        model_instance=object(),
    )
    succeeded = _build_succeeded_event()

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as deduct_spy:
        quota_layer.on_node_run_end(node=llm_node, error=None, result_event=succeeded)

    deduct_spy.assert_called_once_with(
        tenant_id="tenant-id",
        model_instance=llm_node.model_instance,
        usage=succeeded.node_run_result.llm_usage,
    )
def test_deduct_quota_called_for_question_classifier_node() -> None:
    """A succeeded question-classifier node is charged the same way as an LLM node."""
    quota_layer = LLMQuotaLayer()
    classifier_node = MagicMock(
        id="question-classifier-node-id",
        execution_id="execution-id",
        node_type=NodeType.QUESTION_CLASSIFIER,
        tenant_id="tenant-id",
        model_instance=object(),
    )
    succeeded = _build_succeeded_event()
    expected_kwargs = {
        "tenant_id": "tenant-id",
        "model_instance": classifier_node.model_instance,
        "usage": succeeded.node_run_result.llm_usage,
    }

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as deduct_spy:
        quota_layer.on_node_run_end(node=classifier_node, error=None, result_event=succeeded)

    deduct_spy.assert_called_once_with(**expected_kwargs)
def test_non_llm_node_is_ignored() -> None:
    """A node whose type is not quota-relevant (here START) must not trigger a deduction."""
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "start-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.START
    node.tenant_id = "tenant-id"
    # Use the same public attribute as the sibling tests, so the node carries a
    # concrete model instance and is skipped only because of its node type.
    node.model_instance = object()
    result_event = _build_succeeded_event()

    with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
    """A generic error raised by the deduction must not propagate out of the layer."""
    layer = LLMQuotaLayer()
    node = MagicMock()
    node.id = "llm-node-id"
    node.execution_id = "execution-id"
    node.node_type = NodeType.LLM
    node.tenant_id = "tenant-id"
    node.model_instance = object()
    result_event = _build_succeeded_event()

    with patch(
        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
        autospec=True,
        side_effect=ValueError("quota exceeded"),
    ) as mock_deduct:
        # Must return normally even though the deduction raises.
        layer.on_node_run_end(node=node, error=None, result_event=result_event)

    # Without this check the test would also pass if the deduction were never attempted.
    mock_deduct.assert_called_once()
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
    """A QuotaExceededError during deduction sets the stop event and sends one ABORT command."""
    quota_layer = LLMQuotaLayer()
    halt_signal = threading.Event()
    quota_layer.command_channel = MagicMock()
    llm_node = MagicMock(
        id="llm-node-id",
        execution_id="execution-id",
        node_type=NodeType.LLM,
        tenant_id="tenant-id",
        model_instance=object(),
        graph_runtime_state=MagicMock(stop_event=halt_signal),
    )
    succeeded = _build_succeeded_event()

    with patch(
        "core.app.workflow.layers.llm_quota.deduct_llm_quota",
        autospec=True,
        side_effect=QuotaExceededError("No credits remaining"),
    ):
        quota_layer.on_node_run_end(node=llm_node, error=None, result_event=succeeded)

    assert halt_signal.is_set()
    send_command = quota_layer.command_channel.send_command
    send_command.assert_called_once()
    sent = send_command.call_args.args[0]
    assert sent.command_type == CommandType.ABORT
    assert sent.reason == "No credits remaining"
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
    """A QuotaExceededError from the pre-run check sets the stop event and sends one ABORT command."""
    quota_layer = LLMQuotaLayer()
    halt_signal = threading.Event()
    quota_layer.command_channel = MagicMock()
    llm_node = MagicMock(
        id="llm-node-id",
        node_type=NodeType.LLM,
        model_instance=object(),
        graph_runtime_state=MagicMock(stop_event=halt_signal),
    )

    with patch(
        "core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
        autospec=True,
        side_effect=QuotaExceededError("Model provider openai quota exceeded."),
    ):
        quota_layer.on_node_run_start(llm_node)

    assert halt_signal.is_set()
    send_command = quota_layer.command_channel.send_command
    send_command.assert_called_once()
    sent = send_command.call_args.args[0]
    assert sent.command_type == CommandType.ABORT
    assert sent.reason == "Model provider openai quota exceeded."
def test_quota_precheck_passes_without_abort() -> None:
    """When the pre-run check succeeds, the stop event stays clear and no command is sent."""
    quota_layer = LLMQuotaLayer()
    halt_signal = threading.Event()
    quota_layer.command_channel = MagicMock()
    llm_node = MagicMock(
        id="llm-node-id",
        node_type=NodeType.LLM,
        model_instance=object(),
        graph_runtime_state=MagicMock(stop_event=halt_signal),
    )

    with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as precheck_spy:
        quota_layer.on_node_run_start(llm_node)

    assert not halt_signal.is_set()
    precheck_spy.assert_called_once_with(model_instance=llm_node.model_instance)
    quota_layer.command_channel.send_command.assert_not_called()

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import queue
import threading
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=execution_coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
assert event_queue.empty()
@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@ -1,5 +1,4 @@
import queue
import threading
from datetime import datetime
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
event_handler=handler,
execution_coordinator=coordinator,
event_emitter=None,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@ -1,550 +0,0 @@
"""
Unit tests for stop_event functionality in GraphEngine.
Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""
import threading
import time
from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
class TestStopEventPropagation:
"""Test suite for stop_event propagation through GraphEngine components."""
def test_graph_engine_creates_stop_event(self):
"""Test that GraphEngine creates a stop_event on initialization."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify stop_event was created
assert engine._stop_event is not None
assert isinstance(engine._stop_event, threading.Event)
# Verify it was set in graph_runtime_state
assert runtime_state.stop_event is not None
assert runtime_state.stop_event is engine._stop_event
def test_stop_event_cleared_on_start(self):
"""Test that stop_event is cleared when execution starts."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set the stop_event before running
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear the stop_event)
events = list(engine.run())
# After running, stop_event should be set again (by _stop_execution)
# But during start it was cleared
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
def test_stop_event_set_on_stop(self):
"""Test that stop_event is set when execution stops."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Initially not set
assert not engine._stop_event.is_set()
# Run the engine
list(engine.run())
# After execution completes, stop_event should be set
assert engine._stop_event.is_set()
def test_stop_event_passed_to_worker_pool(self):
"""Test that stop_event is passed to WorkerPool."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify WorkerPool has the stop_event
assert engine._worker_pool._stop_event is not None
assert engine._worker_pool._stop_event is engine._stop_event
def test_stop_event_passed_to_dispatcher(self):
"""Test that stop_event is passed to Dispatcher."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify Dispatcher has the stop_event
assert engine._dispatcher._stop_event is not None
assert engine._dispatcher._stop_event is engine._stop_event
class TestNodeStopCheck:
"""Test suite for Node._should_stop() functionality."""
def test_node_should_stop_checks_runtime_state(self):
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Initially stop_event is not set
assert not answer_node._should_stop()
# Set the stop_event
runtime_state.stop_event.set()
# Now _should_stop should return True
assert answer_node._should_stop()
def test_node_run_checks_stop_event_between_yields(self):
"""Test that Node.run() checks stop_event between yielding events."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Set stop_event BEFORE running the node
runtime_state.stop_event.set()
# Run the node - should yield start event then detect stop
# The node should check stop_event before processing
assert answer_node._should_stop(), "stop_event should be set"
# Run and collect events
events = list(answer_node.run())
# Since stop_event is set at the start, we should get:
# 1. NodeRunStartedEvent (always yielded first)
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
assert len(events) >= 2
assert isinstance(events[0], NodeRunStartedEvent)
# Note: AnswerNode is very simple and might complete before stop check
# The important thing is that _should_stop() returns True when stop_event is set
assert answer_node._should_stop()
class TestStopEventIntegration:
"""Integration tests for stop_event in workflow execution."""
def test_simple_workflow_respects_stop_event(self):
"""Test that a simple workflow respects stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.nodes["answer"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set stop_event before running
runtime_state.stop_event.set()
# Run the engine
events = list(engine.run())
# Should get started event but not succeeded (due to stop)
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
# The workflow should still complete (start node runs quickly)
# but answer node might be cancelled depending on timing
def test_stop_event_with_concurrent_nodes(self):
"""Test stop_event behavior with multiple concurrent nodes."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
# Create multiple nodes
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes[f"answer_{i}"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# All nodes should share the same stop_event
for node in mock_graph.nodes.values():
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
assert node.graph_runtime_state.stop_event is engine._stop_event
class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
dispatcher = engine._dispatcher
dispatcher.start() # This will create and start the mocked thread
mock_thread_instance = mock_thread_cls.return_value
mock_thread_instance.is_alive.return_value = True
dispatcher.stop()
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
worker_pool = engine._worker_pool
worker_pool.start(initial_count=1) # Start with one worker
mock_worker_instance = mock_worker_cls.return_value
mock_worker_instance.is_alive.return_value = True
worker_pool.stop()
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
class TestStopEventResumeBehavior:
"""Test stop_event behavior during workflow resume."""
def test_stop_event_cleared_on_resume(self):
"""Test that stop_event is cleared when resuming a paused workflow."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Simulate a previous execution that set stop_event
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear stop_event in _start_execution)
events = list(engine.run())
# Execution should complete successfully
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
class TestWorkerStopBehavior:
"""Test Worker behavior with shared stop_event."""
def test_worker_uses_shared_stop_event(self):
"""Test that Worker uses shared stop_event from GraphEngine."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Get the worker pool and check workers
worker_pool = engine._worker_pool
# Start the worker pool to create workers
worker_pool.start()
# Check that at least one worker was created
assert len(worker_pool._workers) > 0
# Verify workers use the shared stop_event
for worker in worker_pool._workers:
assert worker._stop_event is engine._stop_event
# Clean up
worker_pool.stop()
def test_worker_stop_is_noop(self):
"""Test that Worker.stop() is now a no-op."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a mock worker
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_engine.worker import Worker
ready_queue = InMemoryReadyQueue()
event_queue = MagicMock()
# Create a proper mock graph with real dict
mock_graph = Mock(spec=Graph)
mock_graph.nodes = {} # Use real dict
stop_event = threading.Event()
worker = Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=mock_graph,
layers=[],
stop_event=stop_event,
)
# Calling stop() should do nothing (no-op)
# and should NOT set the stop_event
worker.stop()
assert not stop_event.is_set()

View File

@ -4,12 +4,12 @@ import time
import pytest
from pydantic import ValidationError as PydanticValidationError
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def make_start_node(user_inputs, variables):

View File

@ -0,0 +1,51 @@
from libs.pyrefly_diagnostics import extract_diagnostics
def test_extract_diagnostics_keeps_only_summary_and_location_lines() -> None:
# Purpose: feed extract_diagnostics a sample of pyrefly output and check that the
# result contains only each ERROR summary line and the `-->` location line after it.
# Arrange: one INFO banner followed by two ERROR diagnostics; each diagnostic has a
# summary line, a `-->` location line, and source-excerpt lines (`|`, numbered code, carets).
raw_output = """INFO Checking project configured at `/tmp/project/pyrefly.toml`
ERROR `result` may be uninitialized [unbound-name]
--> controllers/console/app/annotation.py:126:16
|
126 | return result, 200
| ^^^^^^
|
ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]
--> controllers/console/app/app.py:574:13
|
574 | app_model.access_mode = app_setting.access_mode
| ^^^^^^^^^^^^^^^^^^^^^
"""
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert: the INFO banner and the source-excerpt lines are absent; the two
# summary/location pairs remain in input order, each line newline-terminated.
# NOTE(review): whitespace inside these literals is significant to the comparison —
# confirm it matches what extract_diagnostics emits before reformatting this test.
assert diagnostics == (
"ERROR `result` may be uninitialized [unbound-name]\n"
" --> controllers/console/app/annotation.py:126:16\n"
"ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]\n"
" --> controllers/console/app/app.py:574:13\n"
)
def test_extract_diagnostics_handles_error_without_location_line() -> None:
    """A lone ERROR summary with no location line after it comes back unchanged."""
    # Arrange: a single newline-terminated error line and nothing else
    lone_error = "ERROR unexpected pyrefly output format [bad-format]\n"

    # Act
    extracted = extract_diagnostics(lone_error)

    # Assert: output equals the input line exactly
    expected = "ERROR unexpected pyrefly output format [bad-format]\n"
    assert extracted == expected
def test_extract_diagnostics_returns_empty_for_non_error_output() -> None:
    """Output that contains only an INFO line yields an empty string."""
    # Arrange: pyrefly's informational banner, with no ERROR lines
    info_only = "INFO Checking project configured at `/tmp/project/pyrefly.toml`\n"

    # Act
    extracted = extract_diagnostics(info_only)

    # Assert: nothing is kept
    expected = ""
    assert extracted == expected

View File

@ -13,12 +13,11 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter

View File

@ -5,6 +5,7 @@ These tests intentionally stay in unit scope because they validate call argument
for external collaborators rather than SQL-backed state transitions.
"""
import json
import uuid
from unittest.mock import MagicMock, Mock, patch
@ -196,3 +197,78 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
class TestDataSourceInfoSerialization:
    """Regression guard: the sync task must store ``data_source_info`` as a JSON string.

    See https://github.com/langgenius/dify/issues/32705

    psycopg2 raises ``ProgrammingError: can't adapt type 'dict'`` when a Python
    dict is passed directly to a text/LongText column, so a raw dict must never
    be assigned to the document.
    """

    def test_data_source_info_serialized_as_json_string(
        self,
        mock_document,
        mock_dataset,
        dataset_id,
        document_id,
    ):
        """After the task runs, ``data_source_info`` is a ``str`` that parses as JSON."""

        def entering(value):
            """Return a context-manager mock whose ``with`` target is ``value``."""
            manager = MagicMock()
            manager.__enter__.return_value = value
            manager.__exit__.return_value = False
            return manager

        with (
            patch("tasks.document_indexing_sync_task.session_factory") as session_factory_patch,
            patch("tasks.document_indexing_sync_task.DatasourceProviderService") as provider_service_cls,
            patch("tasks.document_indexing_sync_task.NotionExtractor") as notion_extractor_cls,
            patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as index_processor_factory_cls,
            patch("tasks.document_indexing_sync_task.IndexingRunner") as indexing_runner_cls,
        ):
            # --- External collaborators -------------------------------------
            provider_service = MagicMock()
            provider_service.get_datasource_credentials.return_value = {"integration_secret": "token"}
            provider_service_cls.return_value = provider_service

            # The extractor reports a *different* timestamp so that the task
            # takes its sync/update branch.
            notion_extractor = MagicMock()
            notion_extractor.get_notion_last_edited_time.return_value = "2024-02-01T00:00:00Z"
            notion_extractor_cls.return_value = notion_extractor

            index_processor_factory_cls.return_value.init_index_processor.return_value = MagicMock()
            indexing_runner_cls.return_value = MagicMock()

            # --- DB session --------------------------------------------------
            # One session mock is shared by every ``session_factory.create_session()`` call.
            db_session = MagicMock()
            db_session.scalars.return_value.all.return_value = []
            query_result = db_session.query.return_value
            # .where() path: session 1 reads document + dataset, session 2 reads dataset
            query_result.where.return_value.first.side_effect = [
                mock_document,
                mock_dataset,
                mock_dataset,
            ]
            # .filter_by() path: session 3 (update), session 4 (indexing)
            query_result.filter_by.return_value.first.side_effect = [
                mock_document,
                mock_document,
            ]

            # Both ``session.begin()`` and ``create_session()`` hand back the same session.
            db_session.begin.return_value = entering(db_session)
            session_factory_patch.create_session.return_value = entering(db_session)

            # --- Act ---------------------------------------------------------
            document_indexing_sync_task(dataset_id, document_id)

            # --- Assert: a JSON *string* was written, not a dict -------------
            assert isinstance(mock_document.data_source_info, str), (
                f"data_source_info should be a JSON string, got {type(mock_document.data_source_info).__name__}"
            )
            decoded = json.loads(mock_document.data_source_info)
            assert decoded["last_edited_time"] == "2024-02-01T00:00:00Z"

View File

@ -1,50 +0,0 @@
[src]
exclude = [
# deps groups (A1/A2/B/C/D/E)
# B: app runner + prompt
"core/prompt",
"core/app/apps/base_app_runner.py",
"core/app/apps/workflow_app_runner.py",
"core/agent",
"core/plugin",
# C: services/controllers/fields/libs
"services",
"controllers/inner_api",
"controllers/console/app",
"controllers/console/explore",
"controllers/console/datasets",
"controllers/console/workspace",
"controllers/service_api/wraps.py",
"fields/conversation_fields.py",
"libs/external_api.py",
# D: observability + integrations
"core/ops",
"extensions",
# E: vector DB integrations
"core/rag/datasource/vdb",
# non-production or generated code
"migrations",
"tests",
# targeted ignores for current type-check errors
# TODO(QuantumGhost): suppress type errors in HITL related code.
# fix the type error later
"configs/middleware/cache/redis_pubsub_config.py",
"extensions/ext_redis.py",
"models/execution_extra_content.py",
"tasks/workflow_execution_tasks.py",
"core/workflow/nodes/base/node.py",
"services/human_input_delivery_test_service.py",
"core/app/apps/advanced_chat/app_generator.py",
"controllers/console/human_input_form.py",
"controllers/console/app/workflow_run.py",
"repositories/sqlalchemy_api_workflow_node_execution_repository.py",
"extensions/logstore/repositories/logstore_api_workflow_run_repository.py",
"controllers/web/workflow_events.py",
"tasks/app_generate/workflow_execute_task.py",
]
[rules]
deprecated = "ignore"
unused-ignore-comment = "ignore"
# possibly-missing-attribute = "ignore"

42
api/uv.lock generated
View File

@ -1483,7 +1483,6 @@ dev = [
{ name = "scipy-stubs" },
{ name = "sseclient-py" },
{ name = "testcontainers" },
{ name = "ty" },
{ name = "types-aiofiles" },
{ name = "types-beautifulsoup4" },
{ name = "types-cachetools" },
@ -1637,7 +1636,7 @@ requires-dist = [
{ name = "pydantic", specifier = "~=2.12.5" },
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
{ name = "pydantic-settings", specifier = "~=2.12.0" },
{ name = "pyjwt", specifier = "~=2.10.1" },
{ name = "pyjwt", specifier = "~=2.11.0" },
{ name = "pypdfium2", specifier = "==5.2.0" },
{ name = "python-docx", specifier = "~=1.2.0" },
{ name = "python-dotenv", specifier = "==1.0.1" },
@ -1684,8 +1683,7 @@ dev = [
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "sseclient-py", specifier = ">=1.8.0" },
{ name = "testcontainers", specifier = "~=4.13.2" },
{ name = "ty", specifier = ">=0.0.14" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-aiofiles", specifier = "~=25.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
{ name = "types-cachetools", specifier = "~=5.5.0" },
{ name = "types-cffi", specifier = ">=1.17.0" },
@ -4959,11 +4957,11 @@ wheels = [
[[package]]
name = "pyjwt"
version = "2.10.1"
version = "2.11.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
{ url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" },
]
[package.optional-dependencies]
@ -6278,30 +6276,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/70/26/2591b48412bde75e33bfd292034103ffe41743cacd03120e3242516cd143/transformers-4.56.2-py3-none-any.whl", hash = "sha256:79c03d0e85b26cb573c109ff9eafa96f3c8d4febfd8a0774e8bba32702dd6dde", size = 11608055, upload-time = "2025-09-19T15:16:23.736Z" },
]
[[package]]
name = "ty"
version = "0.0.14"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" },
{ url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" },
{ url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" },
{ url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" },
{ url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" },
{ url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" },
{ url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" },
{ url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" },
{ url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" },
{ url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" },
{ url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" },
{ url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" },
{ url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" },
{ url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" },
{ url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" },
{ url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" },
]
[[package]]
name = "typer"
version = "0.20.0"
@ -6319,11 +6293,11 @@ wheels = [
[[package]]
name = "types-aiofiles"
version = "24.1.0.20250822"
version = "25.1.0.20251011"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/19/48/c64471adac9206cc844afb33ed311ac5a65d2f59df3d861e0f2d0cad7414/types_aiofiles-24.1.0.20250822.tar.gz", hash = "sha256:9ab90d8e0c307fe97a7cf09338301e3f01a163e39f3b529ace82466355c84a7b", size = 14484, upload-time = "2025-08-22T03:02:23.039Z" }
sdist = { url = "https://files.pythonhosted.org/packages/84/6c/6d23908a8217e36704aa9c79d99a620f2fdd388b66a4b7f72fbc6b6ff6c6/types_aiofiles-25.1.0.20251011.tar.gz", hash = "sha256:1c2b8ab260cb3cd40c15f9d10efdc05a6e1e6b02899304d80dfa0410e028d3ff", size = 14535, upload-time = "2025-10-11T02:44:51.237Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/8e/5e6d2215e1d8f7c2a94c6e9d0059ae8109ce0f5681956d11bb0a228cef04/types_aiofiles-24.1.0.20250822-py3-none-any.whl", hash = "sha256:0ec8f8909e1a85a5a79aed0573af7901f53120dd2a29771dd0b3ef48e12328b0", size = 14322, upload-time = "2025-08-22T03:02:21.918Z" },
{ url = "https://files.pythonhosted.org/packages/71/0f/76917bab27e270bb6c32addd5968d69e558e5b6f7fb4ac4cbfa282996a96/types_aiofiles-25.1.0.20251011-py3-none-any.whl", hash = "sha256:8ff8de7f9d42739d8f0dadcceeb781ce27cd8d8c4152d4a7c52f6b20edb8149c", size = 14338, upload-time = "2025-10-11T02:44:50.054Z" },
]
[[package]]

View File

@ -149,7 +149,6 @@ services:
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
PM2_INSTANCES: ${PM2_INSTANCES:-2}
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}

View File

@ -844,7 +844,6 @@ services:
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
PM2_INSTANCES: ${PM2_INSTANCES:-2}
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}

View File

@ -50,24 +50,18 @@ ENV MARKETPLACE_API_URL=https://marketplace.dify.ai
ENV MARKETPLACE_URL=https://marketplace.dify.ai
ENV PORT=3000
ENV NEXT_TELEMETRY_DISABLED=1
ENV PM2_INSTANCES=2
# set timezone
ENV TZ=UTC
RUN ln -s /usr/share/zoneinfo/${TZ} /etc/localtime \
&& echo ${TZ} > /etc/timezone
# global runtime packages
RUN pnpm add -g pm2
# Create non-root user
ARG dify_uid=1001
RUN addgroup -S -g ${dify_uid} dify && \
adduser -S -u ${dify_uid} -G dify -s /bin/ash -h /home/dify dify && \
mkdir /app && \
mkdir /.pm2 && \
chown -R dify:dify /app /.pm2
chown -R dify:dify /app
WORKDIR /app/web

View File

@ -89,8 +89,6 @@ If you want to customize the host and port:
pnpm run start --port=3001 --host=0.0.0.0
```
If you want to customize the number of instances launched by PM2, you can configure `PM2_INSTANCES` in `docker-compose.yaml` or `Dockerfile`.
## Storybook
This project uses [Storybook](https://storybook.js.org/) for UI component development.

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Alert from './alert'
import Alert from '../alert'
describe('Alert', () => {
const defaultProps = {

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import AppUnavailable from './app-unavailable'
import AppUnavailable from '../app-unavailable'
describe('AppUnavailable', () => {
beforeEach(() => {

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import Badge from './badge'
import Badge from '../badge'
describe('Badge', () => {
describe('Rendering', () => {

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSelector from './theme-selector'
import ThemeSelector from '../theme-selector'
// Mock next-themes with controllable state
let mockTheme = 'system'

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSwitcher from './theme-switcher'
import ThemeSwitcher from '../theme-switcher'
let mockTheme = 'system'
const mockSetTheme = vi.fn()

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { ActionButton, ActionButtonState } from './index'
import { ActionButton, ActionButtonState } from '../index'
describe('ActionButton', () => {
it('renders button with default props', () => {

View File

@ -4,7 +4,7 @@ import type { AgentLogDetailResponse } from '@/models/log'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAgentLogDetail } from '@/service/log'
import AgentLogDetail from './detail'
import AgentLogDetail from '../detail'
vi.mock('@/service/log', () => ({
fetchAgentLogDetail: vi.fn(),

View File

@ -3,7 +3,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useClickAway } from 'ahooks'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAgentLogDetail } from '@/service/log'
import AgentLogModal from './index'
import AgentLogModal from '../index'
vi.mock('@/service/log', () => ({
fetchAgentLogDetail: vi.fn(),

View File

@ -1,6 +1,6 @@
import type { AgentIteration } from '@/models/log'
import { render, screen } from '@testing-library/react'
import Iteration from './iteration'
import Iteration from '../iteration'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@ -1,6 +1,6 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import ResultPanel from './result'
import ResultPanel from '../result'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@ -2,7 +2,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { describe, expect, it, vi } from 'vitest'
import { BlockEnum } from '@/app/components/workflow/types'
import ToolCallItem from './tool-call'
import ToolCallItem from '../tool-call'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@ -1,7 +1,7 @@
import type { AgentIteration } from '@/models/log'
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import TracingPanel from './tracing'
import TracingPanel from '../tracing'
vi.mock('@/app/components/workflow/block-icon', () => ({
default: () => <div data-testid="block-icon" />,

View File

@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import AnswerIcon from '.'
import AnswerIcon from '..'
describe('AnswerIcon', () => {
it('renders default emoji when no icon or image is provided', () => {

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ImageInput from './ImageInput'
import ImageInput from '../ImageInput'
const createObjectURLMock = vi.fn(() => 'blob:mock-url')
const revokeObjectURLMock = vi.fn()

View File

@ -1,5 +1,5 @@
import { act, renderHook } from '@testing-library/react'
import { useDraggableUploader } from './hooks'
import { useDraggableUploader } from '../hooks'
type MockDragEventOverrides = {
dataTransfer?: { files: File[] }

View File

@ -3,7 +3,7 @@ import type { ImageFile } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { TransferMethod } from '@/types/app'
import AppIconPicker from './index'
import AppIconPicker from '../index'
import 'vitest-canvas-mock'
type LocalFileUploaderOptions = {
@ -93,7 +93,7 @@ vi.mock('react-easy-crop', () => ({
),
}))
vi.mock('../image-uploader/hooks', () => ({
vi.mock('../../image-uploader/hooks', () => ({
useLocalFileUploader: (options: LocalFileUploaderOptions) => {
mocks.onUpload = options.onUpload
return { handleLocalFileUpload: mocks.handleLocalFileUpload }

View File

@ -1,4 +1,4 @@
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from './utils'
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from '../utils'
type ImageLoadEventType = 'load' | 'error'

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AppIcon from './index'
import AppIcon from '../index'
// Mock emoji-mart initialization
vi.mock('emoji-mart', () => ({

View File

@ -2,7 +2,7 @@ import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import i18next from 'i18next'
import { useParams, usePathname } from 'next/navigation'
import AudioBtn from './index'
import AudioBtn from '../index'
const mockPlayAudio = vi.fn()
const mockPauseAudio = vi.fn()

View File

@ -4,7 +4,7 @@ import { vi } from 'vitest'
import useThemeMock from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AudioPlayer from './AudioPlayer'
import AudioPlayer from '../AudioPlayer'
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(() => ({ theme: 'light' })),

View File

@ -3,12 +3,12 @@ import * as React from 'react'
// AudioGallery.spec.tsx
import { describe, expect, it, vi } from 'vitest'
import AudioGallery from './index'
import AudioGallery from '../index'
// Mock AudioPlayer so we only assert prop forwarding
const audioPlayerMock = vi.fn()
vi.mock('./AudioPlayer', () => ({
vi.mock('../AudioPlayer', () => ({
default: (props: { srcs: string[] }) => {
audioPlayerMock(props)
return <div data-testid="audio-player" />

View File

@ -1,6 +1,6 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { sleep } from '@/utils'
import AutoHeightTextarea from './index'
import AutoHeightTextarea from '../index'
vi.mock('@/utils', async () => {
const actual = await vi.importActual('@/utils')

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Avatar from './index'
import Avatar from '../index'
describe('Avatar', () => {
beforeEach(() => {

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Badge, { BadgeState, BadgeVariants } from './index'
import Badge, { BadgeState, BadgeVariants } from '../index'
describe('Badge', () => {
describe('Rendering', () => {

View File

@ -1,7 +1,7 @@
import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import Toast from '@/app/components/base/toast'
import BlockInput, { getInputKeys } from './index'
import BlockInput, { getInputKeys } from '../index'
vi.mock('@/utils/var', () => ({
checkKeys: vi.fn((_keys: string[]) => ({

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AddButton from './add-button'
import AddButton from '../add-button'
describe('AddButton', () => {
describe('Rendering', () => {

View File

@ -1,6 +1,6 @@
import { cleanup, fireEvent, render } from '@testing-library/react'
import * as React from 'react'
import Button from './index'
import Button from '../index'
afterEach(cleanup)
// https://testing-library.com/docs/queries/about

View File

@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import SyncButton from './sync-button'
import SyncButton from '../sync-button'
describe('SyncButton', () => {
describe('Rendering', () => {

View File

@ -1,7 +1,7 @@
import type { Mock } from 'vitest'
import { act, fireEvent, render, screen } from '@testing-library/react'
import useEmblaCarousel from 'embla-carousel-react'
import { Carousel, useCarousel } from './index'
import { Carousel, useCarousel } from '../index'
vi.mock('embla-carousel-react', () => ({
default: vi.fn(),

View File

@ -1,5 +1,5 @@
import type { ChatConfig, ChatItemInTree } from '../types'
import type { ChatWithHistoryContextValue } from './context'
import type { ChatConfig, ChatItemInTree } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import type { HumanInputFormData } from '@/types/workflow'
@ -12,17 +12,17 @@ import {
stopChatMessageResponding,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { useChat } from '../chat/hooks'
import { useChat } from '../../chat/hooks'
import { isValidGeneratedAnswer } from '../utils'
import ChatWrapper from './chat-wrapper'
import { useChatWithHistoryContext } from './context'
import { isValidGeneratedAnswer } from '../../utils'
import ChatWrapper from '../chat-wrapper'
import { useChatWithHistoryContext } from '../context'
vi.mock('../chat/hooks', () => ({
vi.mock('../../chat/hooks', () => ({
useChat: vi.fn(),
}))
vi.mock('./context', () => ({
vi.mock('../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))
@ -37,7 +37,7 @@ vi.mock('next/navigation', () => ({
useParams: vi.fn(() => ({ token: 'test-token' })),
}))
vi.mock('../utils', () => ({
vi.mock('../../utils', () => ({
isValidGeneratedAnswer: vi.fn(),
getLastAnswer: vi.fn(),
}))

View File

@ -1,12 +1,12 @@
import type { ChatConfig } from '../types'
import type { ChatWithHistoryContextValue } from './context'
import type { ChatConfig } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useChatWithHistoryContext } from './context'
import HeaderInMobile from './header-in-mobile'
import { useChatWithHistoryContext } from '../context'
import HeaderInMobile from '../header-in-mobile'
vi.mock('@/hooks/use-breakpoints', () => ({
default: vi.fn(),
@ -17,7 +17,7 @@ vi.mock('@/hooks/use-breakpoints', () => ({
},
}))
vi.mock('./context', () => ({
vi.mock('../context', () => ({
useChatWithHistoryContext: vi.fn(),
ChatWithHistoryContext: { Provider: ({ children }: { children: React.ReactNode }) => <div>{children}</div> },
}))
@ -33,7 +33,7 @@ vi.mock('next/navigation', () => ({
useParams: vi.fn(() => ({})),
}))
vi.mock('../embedded-chatbot/theme/theme-context', () => ({
vi.mock('../../embedded-chatbot/theme/theme-context', () => ({
useThemeContext: vi.fn(() => ({
buildTheme: vi.fn(),
})),

Some files were not shown because too many files have changed in this diff Show More