mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00
Merge remote-tracking branch 'myori/main' into p363
This commit is contained in:
commit
92ab5dae98
12
.github/workflows/pyrefly-diff.yml
vendored
12
.github/workflows/pyrefly-diff.yml
vendored
@ -29,20 +29,26 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Prepare diagnostics extractor
|
||||
run: |
|
||||
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
|
||||
|
||||
- name: Run pyrefly on PR branch
|
||||
run: |
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly on base branch
|
||||
run: |
|
||||
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
|
||||
|
||||
- name: Compute diff
|
||||
run: |
|
||||
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
|
||||
5
Makefile
5
Makefile
@ -68,10 +68,9 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
|
||||
@echo "📝 Running type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@cd api && uv run ty check
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
test:
|
||||
@ -132,7 +131,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@ -29,6 +29,8 @@ ignore_imports =
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
@ -52,7 +54,6 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
@ -107,14 +108,11 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
@ -131,12 +129,10 @@ ignore_imports =
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.start.entities -> core.app.app_config.entities
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
@ -150,7 +146,6 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
@ -172,7 +167,6 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
@ -180,7 +174,7 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
|
||||
@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
|
||||
@ -2,12 +2,12 @@ from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
|
||||
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
|
||||
|
||||
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
class RagPipelineVariableEntity(WorkflowVariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
@ -314,7 +260,7 @@ class AppConfig(BaseModel):
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures | None = None
|
||||
variables: list[VariableEntity] = []
|
||||
variables: list[WorkflowVariableEntity] = []
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
|
||||
@ -1 +1,5 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
|
||||
93
api/core/app/llm/quota.py
Normal file
93
api/core/app/llm/quota.py
Normal file
@ -0,0 +1,93 @@
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
@ -1,9 +1,11 @@
|
||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||
|
||||
from .llm_quota import LLMQuotaLayer
|
||||
from .observability import ObservabilityLayer
|
||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
|
||||
__all__ = [
|
||||
"LLMQuotaLayer",
|
||||
"ObservabilityLayer",
|
||||
"PersistenceWorkflowInfo",
|
||||
"WorkflowPersistenceLayer",
|
||||
|
||||
128
api/core/app/workflow/layers/llm_quota.py
Normal file
128
api/core/app/workflow/layers/llm_quota.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class LLMQuotaLayer(GraphEngineLayer):
|
||||
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
_ = event
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
_ = error
|
||||
|
||||
@override
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
if self._abort_sent:
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
deduct_llm_quota(
|
||||
tenant_id=node.tenant_id,
|
||||
model_instance=model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
|
||||
except Exception:
|
||||
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
|
||||
|
||||
@staticmethod
|
||||
def _set_stop_event(node: Node) -> None:
|
||||
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
|
||||
try:
|
||||
self.command_channel.send_command(
|
||||
AbortCommand(
|
||||
command_type=CommandType.ABORT,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
self._abort_sent = True
|
||||
except Exception:
|
||||
logger.exception("Failed to send quota abort command")
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case NodeType.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
@ -11,14 +13,16 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeType, SystemVariableKey
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@ -29,11 +33,9 @@ from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import PromptMessageMemory
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
app_id: str,
|
||||
node_data_memory: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory or not conversation_id:
|
||||
return None
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
|
||||
def execute(
|
||||
self,
|
||||
@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
return llm_utils.fetch_memory(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
conversation_id = (
|
||||
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
|
||||
)
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.graph_init_params.app_id,
|
||||
node_data_memory=node_memory,
|
||||
model_instance=model_instance,
|
||||
|
||||
@ -4,10 +4,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types as mcp_types
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
3
api/core/model_runtime/memory/__init__.py
Normal file
3
api/core/model_runtime/memory/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
|
||||
|
||||
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
|
||||
18
api/core/model_runtime/memory/prompt_message_memory.py
Normal file
18
api/core/model_runtime/memory/prompt_message_memory.py
Normal file
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
|
||||
"""Port for loading memory as prompt messages."""
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Return historical prompt messages constrained by token/message limits."""
|
||||
...
|
||||
@ -2,6 +2,7 @@ import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
@ -29,7 +30,6 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
|
||||
@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
)
|
||||
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
|
||||
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs import helper
|
||||
@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
# Deduct quota for summary generation (same as workflow nodes)
|
||||
try:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
except Exception as e:
|
||||
# Log but don't fail summary generation if quota deduction fails
|
||||
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@ -6,9 +6,9 @@ identity:
|
||||
zh_Hans: 网页抓取
|
||||
pt_BR: WebScraper
|
||||
description:
|
||||
en_US: Web Scrapper tool kit is used to scrape web
|
||||
en_US: Web Scraper tool kit is used to scrape web
|
||||
zh_Hans: 一个用于抓取网页的工具。
|
||||
pt_BR: Web Scrapper tool kit is used to scrape web
|
||||
pt_BR: Web Scraper tool kit is used to scrape web
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
|
||||
@ -5,7 +5,6 @@ from collections.abc import Mapping
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
|
||||
@ -9,7 +9,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -77,13 +76,10 @@ class GraphEngine:
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
@ -163,7 +159,6 @@ class GraphEngine:
|
||||
layers=self._layers,
|
||||
execution_context=execution_context,
|
||||
config=self._config,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -194,7 +189,6 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -314,7 +308,6 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
@ -348,7 +341,6 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
@ -44,7 +44,6 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -62,7 +61,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -70,12 +69,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
|
||||
@ -42,7 +42,6 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@ -63,16 +62,13 @@ class Worker(threading.Thread):
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._execution_context = execution_context
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -37,7 +37,6 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
config: GraphEngineConfig,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@ -64,7 +63,6 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +133,6 @@ class WorkerPool:
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
execution_context=self._execution_context,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@ -1,26 +1,19 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities import PromptMessageRole
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
||||
|
||||
from .exc import InvalidVariableTypeError
|
||||
|
||||
@ -48,88 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
def convert_history_messages_to_text(
|
||||
*,
|
||||
history_messages: Sequence[PromptMessage],
|
||||
human_prefix: str,
|
||||
ai_prefix: str,
|
||||
) -> str:
|
||||
string_messages: list[str] = []
|
||||
for message in history_messages:
|
||||
if message.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif message.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
used_quota = 1
|
||||
continue
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
if isinstance(message.content, list):
|
||||
content_parts = []
|
||||
for content in message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
content_parts.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
content_parts.append("[image]")
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
inner_msg = "\n".join(content_parts)
|
||||
string_messages.append(f"{role}: {inner_msg}")
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
string_messages.append(f"{role}: {message.content}")
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def fetch_memory_text(
|
||||
*,
|
||||
memory: PromptMessageMemory,
|
||||
max_token_limit: int,
|
||||
message_limit: int | None = None,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
) -> str:
|
||||
history_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
return convert_history_messages_to_text(
|
||||
history_messages=history_messages,
|
||||
human_prefix=human_prefix,
|
||||
ai_prefix=ai_prefix,
|
||||
)
|
||||
|
||||
@ -37,6 +37,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@ -62,7 +63,7 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables import (
|
||||
ArrayFileSegment,
|
||||
@ -278,8 +279,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
@ -1234,6 +1233,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def retry(self) -> bool:
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
@ -1336,48 +1339,16 @@ def _handle_memory_completion_mode(
|
||||
)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
memory_messages = memory.get_history_prompt_messages(
|
||||
memory_text = llm_utils.fetch_memory_text(
|
||||
memory=memory,
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
)
|
||||
memory_text = _convert_history_messages_to_text(
|
||||
history_messages=memory_messages,
|
||||
human_prefix=memory_config.role_prefix.user,
|
||||
ai_prefix=memory_config.role_prefix.assistant,
|
||||
)
|
||||
return memory_text
|
||||
|
||||
|
||||
def _convert_history_messages_to_text(
|
||||
*,
|
||||
history_messages: Sequence[PromptMessage],
|
||||
human_prefix: str,
|
||||
ai_prefix: str,
|
||||
) -> str:
|
||||
string_messages: list[str] = []
|
||||
for message in history_messages:
|
||||
if message.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif message.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content_parts = []
|
||||
for content in message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
content_parts.append(content.data)
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
content_parts.append("[image]")
|
||||
|
||||
inner_msg = "\n".join(content_parts)
|
||||
string_messages.append(f"{role}: {inner_msg}")
|
||||
else:
|
||||
string_messages.append(f"{role}: {message.content}")
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
|
||||
@ -21,13 +19,3 @@ class ModelFactory(Protocol):
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||
...
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
|
||||
"""Port for loading memory as prompt messages for LLM nodes."""
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Return historical prompt messages constrained by token/message limits."""
|
||||
...
|
||||
|
||||
@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@ -5,7 +5,6 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -18,13 +17,18 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.file import File
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base import variable_template_parser
|
||||
@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
_model_instance: ModelInstance
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_memory: PromptMessageMemory | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -108,6 +113,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@ -118,6 +124,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@ -163,13 +170,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
except ValueError as exc:
|
||||
raise ModelSchemaNotFoundError("Model schema not found") from exc
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
memory = self._memory
|
||||
|
||||
if (
|
||||
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
|
||||
@ -308,9 +309,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@ -319,7 +317,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
@ -407,7 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@ -445,7 +443,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@ -470,7 +468,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
files=files,
|
||||
context="",
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
# AdvancedPromptTransform is still typed against TokenBufferMemory.
|
||||
memory=cast(Any, memory),
|
||||
model_instance=model_instance,
|
||||
image_detail_config=vision_detail,
|
||||
)
|
||||
@ -483,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_instance: ModelInstance,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@ -715,7 +714,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -724,8 +723,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
@ -742,7 +741,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -751,8 +750,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
instruction = variable_pool.convert_template(node_data.instruction or "").text
|
||||
|
||||
if memory and node_data.memory and node_data.memory.window:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
|
||||
)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
@ -828,6 +827,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return rest_tokens
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@ -3,9 +3,9 @@ import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@ -56,6 +56,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_model_instance: ModelInstance
|
||||
_memory: PromptMessageMemory | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -67,6 +68,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@ -81,6 +83,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@ -103,13 +106,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variables = {"query": query}
|
||||
# fetch model instance
|
||||
model_instance = self._model_instance
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
memory = self._memory
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
|
||||
@ -240,6 +237,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@ -323,7 +324,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: PromptMessageMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@ -336,7 +337,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
input_text = query
|
||||
memory_str = ""
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(
|
||||
memory_str = llm_utils.fetch_memory_text(
|
||||
memory=memory,
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
|
||||
)
|
||||
|
||||
@ -2,8 +2,8 @@ from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
|
||||
|
||||
@ -2,12 +2,12 @@ from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
|
||||
|
||||
class StartNode(Node[StartNodeData]):
|
||||
|
||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@ -219,8 +218,6 @@ class GraphRuntimeState:
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from .input_entities import VariableEntity, VariableEntityType
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import (
|
||||
ArrayAnySegment,
|
||||
@ -64,4 +65,6 @@ __all__ = [
|
||||
"StringVariable",
|
||||
"Variable",
|
||||
"VariableBase",
|
||||
"VariableEntity",
|
||||
"VariableEntityType",
|
||||
]
|
||||
|
||||
62
api/core/workflow/variables/input_entities.py
Normal file
62
api/core/workflow/variables/input_entities.py
Normal file
@ -0,0 +1,62 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.workflow.file import FileTransferMethod, FileType
|
||||
|
||||
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Shared variable entity used by workflow runtime and app configuration.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, value: Any) -> str:
|
||||
return value or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, value: Any) -> Sequence[str]:
|
||||
return value or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as error:
|
||||
raise ValueError(f"Invalid JSON schema: {error.message}")
|
||||
return schema
|
||||
@ -6,6 +6,7 @@ from typing import Any, cast
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
@ -106,6 +107,7 @@ class WorkflowEntry:
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
self.graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
# Add observability layer when OTel is enabled
|
||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||
|
||||
48
api/libs/pyrefly_diagnostics.py
Normal file
48
api/libs/pyrefly_diagnostics.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
|
||||
_LOCATION_PREFIX = "-->"
|
||||
|
||||
|
||||
def extract_diagnostics(raw_output: str) -> str:
|
||||
"""Extract stable diagnostic lines from pyrefly output.
|
||||
|
||||
The full pyrefly output includes code excerpts and carets, which create noisy
|
||||
diffs. This helper keeps only:
|
||||
- diagnostic headline lines (``ERROR ...`` / ``WARNING ...``)
|
||||
- the following location line (``--> path:line:column``), when present
|
||||
"""
|
||||
|
||||
lines = raw_output.splitlines()
|
||||
diagnostics: list[str] = []
|
||||
|
||||
for index, line in enumerate(lines):
|
||||
if line.startswith(_DIAGNOSTIC_PREFIXES):
|
||||
diagnostics.append(line.rstrip())
|
||||
|
||||
next_index = index + 1
|
||||
if next_index < len(lines):
|
||||
next_line = lines[next_index]
|
||||
if next_line.lstrip().startswith(_LOCATION_PREFIX):
|
||||
diagnostics.append(next_line.rstrip())
|
||||
|
||||
if not diagnostics:
|
||||
return ""
|
||||
|
||||
return "\n".join(diagnostics) + "\n"
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Read pyrefly output from stdin and print normalized diagnostics."""
|
||||
|
||||
raw_output = sys.stdin.read()
|
||||
sys.stdout.write(extract_diagnostics(raw_output))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -787,7 +787,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
|
||||
@ -68,7 +68,7 @@ dependencies = [
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-extra-types~=2.10.3",
|
||||
"pydantic-settings~=2.12.0",
|
||||
"pyjwt~=2.10.1",
|
||||
"pyjwt~=2.11.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.0.1",
|
||||
@ -116,7 +116,6 @@ dev = [
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=38.2.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"ty>=0.0.14",
|
||||
"basedpyright~=1.31.0",
|
||||
"ruff~=0.14.0",
|
||||
"pytest~=8.3.2",
|
||||
@ -125,7 +124,7 @@ dev = [
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"testcontainers~=4.13.2",
|
||||
"types-aiofiles~=24.1.0",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
"types-colorama~=0.4.15",
|
||||
|
||||
@ -29,7 +29,7 @@ from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, delete, func, null, or_, select
|
||||
from sqlalchemy import and_, delete, func, null, or_, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
@ -423,9 +423,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
if last_seen:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
WorkflowRun.created_at > last_seen[0],
|
||||
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
|
||||
tuple_(WorkflowRun.created_at, WorkflowRun.id)
|
||||
> tuple_(
|
||||
sa.literal(last_seen[0], type_=sa.DateTime()),
|
||||
sa.literal(last_seen[1], type_=WorkflowRun.id.type),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from core.app.app_config.entities import (
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
)
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
@ -20,6 +19,7 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.file.models import FileUploadConfig
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
|
||||
@ -9,7 +9,6 @@ from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -40,6 +39,7 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import load_into_variable_pool
|
||||
from core.workflow.variables import VariableBase
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
from core.workflow.variables.variables import Variable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
@ -125,7 +126,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = data_source_info
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
|
||||
|
||||
def get_mocked_fetch_memory(memory_text: str):
|
||||
class MemoryMock:
|
||||
def get_history_prompt_text(
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
):
|
||||
return memory_text
|
||||
return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
|
||||
|
||||
return MagicMock(return_value=MemoryMock())
|
||||
|
||||
|
||||
def init_parameter_extractor_node(config: dict):
|
||||
def init_parameter_extractor_node(config: dict, memory=None):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
memory=memory,
|
||||
)
|
||||
return node
|
||||
|
||||
@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
|
||||
assert result["location"] == "kawaii"
|
||||
|
||||
|
||||
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
|
||||
def test_chat_parameter_extractor_with_memory(setup_model_mock):
|
||||
"""
|
||||
Test chat parameter extractor with memory.
|
||||
"""
|
||||
@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
|
||||
"memory": {"window": {"enabled": True, "size": 50}},
|
||||
},
|
||||
},
|
||||
memory=get_mocked_fetch_memory("customized memory")(),
|
||||
)
|
||||
|
||||
node._model_instance = get_mocked_fetch_model_instance(
|
||||
@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
|
||||
mode="chat",
|
||||
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
|
||||
)()
|
||||
# Test the mock before running the actual test
|
||||
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
|
||||
db.session.close = MagicMock()
|
||||
|
||||
result = node._run()
|
||||
|
||||
@ -10,11 +10,10 @@ from core.app.app_config.entities import (
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models import Account, Tenant
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
|
||||
@ -147,8 +147,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
document.cleaning_completed_at = fake.date_time_this_year()
|
||||
document.splitting_completed_at = fake.date_time_this_year()
|
||||
document.tokens = fake.random_int(min=50, max=500)
|
||||
document.indexing_started_at = fake.date_time_this_year()
|
||||
document.indexing_completed_at = fake.date_time_this_year()
|
||||
document.completed_at = fake.date_time_this_year()
|
||||
document.indexing_status = "completed"
|
||||
document.enabled = True
|
||||
document.archived = False
|
||||
|
||||
@ -12,8 +12,6 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from psycopg2.extensions import register_adapter
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -21,12 +19,6 @@ from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _register_dict_adapter_for_psycopg2():
|
||||
"""Align test DB adapter behavior with dict payloads used in task update flow."""
|
||||
register_adapter(dict, Json)
|
||||
|
||||
|
||||
class DocumentIndexingSyncTaskTestDataFactory:
|
||||
"""Create real DB entities for document indexing sync integration tests."""
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
def test_validate_inputs_with_zero():
|
||||
|
||||
@ -4,7 +4,6 @@ from unittest.mock import Mock, patch
|
||||
import jsonschema
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import (
|
||||
@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import (
|
||||
prepare_tool_arguments,
|
||||
process_mapping_response,
|
||||
)
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,174 @@
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.commands import CommandType
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="execution-id",
|
||||
node_id="llm-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"question": "hello"},
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "question-classifier-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_non_llm_node_is_ignored() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "start-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.START
|
||||
node.tenant_id = "tenant-id"
|
||||
node._model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_error_is_handled_in_layer() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
assert stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "No credits remaining"
|
||||
|
||||
|
||||
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=node.model_instance)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=execution_coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
dispatcher._dispatcher_loop()
|
||||
assert event_queue.empty()
|
||||
@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import queue
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
|
||||
event_handler=handler,
|
||||
execution_coordinator=coordinator,
|
||||
event_emitter=None,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@ -1,550 +0,0 @@
|
||||
"""
|
||||
Unit tests for stop_event functionality in GraphEngine.
|
||||
|
||||
Tests the unified stop_event management by GraphEngine and its propagation
|
||||
to WorkerPool, Worker, Dispatcher, and Nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class TestStopEventPropagation:
|
||||
"""Test suite for stop_event propagation through GraphEngine components."""
|
||||
|
||||
def test_graph_engine_creates_stop_event(self):
|
||||
"""Test that GraphEngine creates a stop_event on initialization."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify stop_event was created
|
||||
assert engine._stop_event is not None
|
||||
assert isinstance(engine._stop_event, threading.Event)
|
||||
|
||||
# Verify it was set in graph_runtime_state
|
||||
assert runtime_state.stop_event is not None
|
||||
assert runtime_state.stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_cleared_on_start(self):
|
||||
"""Test that stop_event is cleared when execution starts."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Set the stop_event before running
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear the stop_event)
|
||||
events = list(engine.run())
|
||||
|
||||
# After running, stop_event should be set again (by _stop_execution)
|
||||
# But during start it was cleared
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
def test_stop_event_set_on_stop(self):
|
||||
"""Test that stop_event is set when execution stops."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Initially not set
|
||||
assert not engine._stop_event.is_set()
|
||||
|
||||
# Run the engine
|
||||
list(engine.run())
|
||||
|
||||
# After execution completes, stop_event should be set
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
def test_stop_event_passed_to_worker_pool(self):
|
||||
"""Test that stop_event is passed to WorkerPool."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify WorkerPool has the stop_event
|
||||
assert engine._worker_pool._stop_event is not None
|
||||
assert engine._worker_pool._stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_passed_to_dispatcher(self):
|
||||
"""Test that stop_event is passed to Dispatcher."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify Dispatcher has the stop_event
|
||||
assert engine._dispatcher._stop_event is not None
|
||||
assert engine._dispatcher._stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestNodeStopCheck:
|
||||
"""Test suite for Node._should_stop() functionality."""
|
||||
|
||||
def test_node_should_stop_checks_runtime_state(self):
|
||||
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Initially stop_event is not set
|
||||
assert not answer_node._should_stop()
|
||||
|
||||
# Set the stop_event
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Now _should_stop should return True
|
||||
assert answer_node._should_stop()
|
||||
|
||||
def test_node_run_checks_stop_event_between_yields(self):
|
||||
"""Test that Node.run() checks stop_event between yielding events."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a simple node
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Set stop_event BEFORE running the node
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the node - should yield start event then detect stop
|
||||
# The node should check stop_event before processing
|
||||
assert answer_node._should_stop(), "stop_event should be set"
|
||||
|
||||
# Run and collect events
|
||||
events = list(answer_node.run())
|
||||
|
||||
# Since stop_event is set at the start, we should get:
|
||||
# 1. NodeRunStartedEvent (always yielded first)
|
||||
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
|
||||
assert len(events) >= 2
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
|
||||
# Note: AnswerNode is very simple and might complete before stop check
|
||||
# The important thing is that _should_stop() returns True when stop_event is set
|
||||
assert answer_node._should_stop()
|
||||
|
||||
|
||||
class TestStopEventIntegration:
|
||||
"""Integration tests for stop_event in workflow execution."""
|
||||
|
||||
def test_simple_workflow_respects_stop_event(self):
|
||||
"""Test that a simple workflow respects stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
# Create start and answer nodes
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.nodes["answer"] = answer_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Set stop_event before running
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the engine
|
||||
events = list(engine.run())
|
||||
|
||||
# Should get started event but not succeeded (due to stop)
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
# The workflow should still complete (start node runs quickly)
|
||||
# but answer node might be cancelled depending on timing
|
||||
|
||||
def test_stop_event_with_concurrent_nodes(self):
|
||||
"""Test stop_event behavior with multiple concurrent nodes."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
# Create multiple nodes
|
||||
for i in range(3):
|
||||
answer_node = AnswerNode(
|
||||
id=f"answer_{i}",
|
||||
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes[f"answer_{i}"] = answer_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# All nodes should share the same stop_event
|
||||
for node in mock_graph.nodes.values():
|
||||
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
|
||||
assert node.graph_runtime_state.stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestStopEventTimeoutBehavior:
|
||||
"""Test stop_event behavior with join timeouts."""
|
||||
|
||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
|
||||
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
|
||||
"""Test that Dispatcher uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
dispatcher = engine._dispatcher
|
||||
dispatcher.start() # This will create and start the mocked thread
|
||||
|
||||
mock_thread_instance = mock_thread_cls.return_value
|
||||
mock_thread_instance.is_alive.return_value = True
|
||||
|
||||
dispatcher.stop()
|
||||
|
||||
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
|
||||
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
|
||||
"""Test that WorkerPool uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
worker_pool = engine._worker_pool
|
||||
worker_pool.start(initial_count=1) # Start with one worker
|
||||
|
||||
mock_worker_instance = mock_worker_cls.return_value
|
||||
mock_worker_instance.is_alive.return_value = True
|
||||
|
||||
worker_pool.stop()
|
||||
|
||||
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestStopEventResumeBehavior:
|
||||
"""Test stop_event behavior during workflow resume."""
|
||||
|
||||
def test_stop_event_cleared_on_resume(self):
|
||||
"""Test that stop_event is cleared when resuming a paused workflow."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Simulate a previous execution that set stop_event
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear stop_event in _start_execution)
|
||||
events = list(engine.run())
|
||||
|
||||
# Execution should complete successfully
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
|
||||
class TestWorkerStopBehavior:
|
||||
"""Test Worker behavior with shared stop_event."""
|
||||
|
||||
def test_worker_uses_shared_stop_event(self):
|
||||
"""Test that Worker uses shared stop_event from GraphEngine."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Get the worker pool and check workers
|
||||
worker_pool = engine._worker_pool
|
||||
|
||||
# Start the worker pool to create workers
|
||||
worker_pool.start()
|
||||
|
||||
# Check that at least one worker was created
|
||||
assert len(worker_pool._workers) > 0
|
||||
|
||||
# Verify workers use the shared stop_event
|
||||
for worker in worker_pool._workers:
|
||||
assert worker._stop_event is engine._stop_event
|
||||
|
||||
# Clean up
|
||||
worker_pool.stop()
|
||||
|
||||
def test_worker_stop_is_noop(self):
|
||||
"""Test that Worker.stop() is now a no-op."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a mock worker
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.worker import Worker
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
event_queue = MagicMock()
|
||||
|
||||
# Create a proper mock graph with real dict
|
||||
mock_graph = Mock(spec=Graph)
|
||||
mock_graph.nodes = {} # Use real dict
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=ready_queue,
|
||||
event_queue=event_queue,
|
||||
graph=mock_graph,
|
||||
layers=[],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
# Calling stop() should do nothing (no-op)
|
||||
# and should NOT set the stop_event
|
||||
worker.stop()
|
||||
assert not stop_event.is_set()
|
||||
@ -4,12 +4,12 @@ import time
|
||||
import pytest
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
def make_start_node(user_inputs, variables):
|
||||
|
||||
51
api/tests/unit_tests/libs/test_pyrefly_diagnostics.py
Normal file
51
api/tests/unit_tests/libs/test_pyrefly_diagnostics.py
Normal file
@ -0,0 +1,51 @@
|
||||
from libs.pyrefly_diagnostics import extract_diagnostics
|
||||
|
||||
|
||||
def test_extract_diagnostics_keeps_only_summary_and_location_lines() -> None:
|
||||
# Arrange
|
||||
raw_output = """INFO Checking project configured at `/tmp/project/pyrefly.toml`
|
||||
ERROR `result` may be uninitialized [unbound-name]
|
||||
--> controllers/console/app/annotation.py:126:16
|
||||
|
|
||||
126 | return result, 200
|
||||
| ^^^^^^
|
||||
|
|
||||
ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]
|
||||
--> controllers/console/app/app.py:574:13
|
||||
|
|
||||
574 | app_model.access_mode = app_setting.access_mode
|
||||
| ^^^^^^^^^^^^^^^^^^^^^
|
||||
"""
|
||||
|
||||
# Act
|
||||
diagnostics = extract_diagnostics(raw_output)
|
||||
|
||||
# Assert
|
||||
assert diagnostics == (
|
||||
"ERROR `result` may be uninitialized [unbound-name]\n"
|
||||
" --> controllers/console/app/annotation.py:126:16\n"
|
||||
"ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]\n"
|
||||
" --> controllers/console/app/app.py:574:13\n"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_diagnostics_handles_error_without_location_line() -> None:
|
||||
# Arrange
|
||||
raw_output = "ERROR unexpected pyrefly output format [bad-format]\n"
|
||||
|
||||
# Act
|
||||
diagnostics = extract_diagnostics(raw_output)
|
||||
|
||||
# Assert
|
||||
assert diagnostics == "ERROR unexpected pyrefly output format [bad-format]\n"
|
||||
|
||||
|
||||
def test_extract_diagnostics_returns_empty_for_non_error_output() -> None:
|
||||
# Arrange
|
||||
raw_output = "INFO Checking project configured at `/tmp/project/pyrefly.toml`\n"
|
||||
|
||||
# Act
|
||||
diagnostics = extract_diagnostics(raw_output)
|
||||
|
||||
# Assert
|
||||
assert diagnostics == ""
|
||||
@ -13,12 +13,11 @@ from core.app.app_config.entities import (
|
||||
ExternalDataVariableEntity,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
from models.model import AppMode
|
||||
from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
@ -5,6 +5,7 @@ These tests intentionally stay in unit scope because they validate call argument
|
||||
for external collaborators rather than SQL-backed state transitions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
@ -196,3 +197,78 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
|
||||
class TestDataSourceInfoSerialization:
|
||||
"""Regression test: data_source_info must be written as a JSON string, not a raw dict.
|
||||
|
||||
See https://github.com/langgenius/dify/issues/32705
|
||||
psycopg2 raises ``ProgrammingError: can't adapt type 'dict'`` when a Python
|
||||
dict is passed directly to a text/LongText column.
|
||||
"""
|
||||
|
||||
def test_data_source_info_serialized_as_json_string(
|
||||
self,
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""data_source_info must be serialized with json.dumps before DB write."""
|
||||
with (
|
||||
patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory,
|
||||
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class,
|
||||
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class,
|
||||
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_ipf,
|
||||
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class,
|
||||
):
|
||||
# External collaborators
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "token"}
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
mock_extractor = MagicMock()
|
||||
# Return a *different* timestamp so the task enters the sync/update branch
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-02-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
mock_ip = MagicMock()
|
||||
mock_ipf.return_value.init_index_processor.return_value = mock_ip
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner_class.return_value = mock_runner
|
||||
|
||||
# DB session mock — shared across all ``session_factory.create_session()`` calls
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
# .where() path: session 1 reads document + dataset, session 2 reads dataset
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
]
|
||||
# .filter_by() path: session 3 (update), session 4 (indexing)
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
begin_cm.__exit__.return_value = False
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
mock_session_factory.create_session.return_value = session_cm
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert: data_source_info must be a JSON *string*, not a dict
|
||||
assert isinstance(mock_document.data_source_info, str), (
|
||||
f"data_source_info should be a JSON string, got {type(mock_document.data_source_info).__name__}"
|
||||
)
|
||||
parsed = json.loads(mock_document.data_source_info)
|
||||
assert parsed["last_edited_time"] == "2024-02-01T00:00:00Z"
|
||||
|
||||
50
api/ty.toml
50
api/ty.toml
@ -1,50 +0,0 @@
|
||||
[src]
|
||||
exclude = [
|
||||
# deps groups (A1/A2/B/C/D/E)
|
||||
# B: app runner + prompt
|
||||
"core/prompt",
|
||||
"core/app/apps/base_app_runner.py",
|
||||
"core/app/apps/workflow_app_runner.py",
|
||||
"core/agent",
|
||||
"core/plugin",
|
||||
# C: services/controllers/fields/libs
|
||||
"services",
|
||||
"controllers/inner_api",
|
||||
"controllers/console/app",
|
||||
"controllers/console/explore",
|
||||
"controllers/console/datasets",
|
||||
"controllers/console/workspace",
|
||||
"controllers/service_api/wraps.py",
|
||||
"fields/conversation_fields.py",
|
||||
"libs/external_api.py",
|
||||
# D: observability + integrations
|
||||
"core/ops",
|
||||
"extensions",
|
||||
# E: vector DB integrations
|
||||
"core/rag/datasource/vdb",
|
||||
# non-producition or generated code
|
||||
"migrations",
|
||||
"tests",
|
||||
# targeted ignores for current type-check errors
|
||||
# TODO(QuantumGhost): suppress type errors in HITL related code.
|
||||
# fix the type error later
|
||||
"configs/middleware/cache/redis_pubsub_config.py",
|
||||
"extensions/ext_redis.py",
|
||||
"models/execution_extra_content.py",
|
||||
"tasks/workflow_execution_tasks.py",
|
||||
"core/workflow/nodes/base/node.py",
|
||||
"services/human_input_delivery_test_service.py",
|
||||
"core/app/apps/advanced_chat/app_generator.py",
|
||||
"controllers/console/human_input_form.py",
|
||||
"controllers/console/app/workflow_run.py",
|
||||
"repositories/sqlalchemy_api_workflow_node_execution_repository.py",
|
||||
"extensions/logstore/repositories/logstore_api_workflow_run_repository.py",
|
||||
"controllers/web/workflow_events.py",
|
||||
"tasks/app_generate/workflow_execute_task.py",
|
||||
]
|
||||
|
||||
|
||||
[rules]
|
||||
deprecated = "ignore"
|
||||
unused-ignore-comment = "ignore"
|
||||
# possibly-missing-attribute = "ignore"
|
||||
42
api/uv.lock
generated
42
api/uv.lock
generated
@ -1483,7 +1483,6 @@ dev = [
|
||||
{ name = "scipy-stubs" },
|
||||
{ name = "sseclient-py" },
|
||||
{ name = "testcontainers" },
|
||||
{ name = "ty" },
|
||||
{ name = "types-aiofiles" },
|
||||
{ name = "types-beautifulsoup4" },
|
||||
{ name = "types-cachetools" },
|
||||
@ -1637,7 +1636,7 @@ requires-dist = [
|
||||
{ name = "pydantic", specifier = "~=2.12.5" },
|
||||
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
|
||||
{ name = "pydantic-settings", specifier = "~=2.12.0" },
|
||||
{ name = "pyjwt", specifier = "~=2.10.1" },
|
||||
{ name = "pyjwt", specifier = "~=2.11.0" },
|
||||
{ name = "pypdfium2", specifier = "==5.2.0" },
|
||||
{ name = "python-docx", specifier = "~=1.2.0" },
|
||||
{ name = "python-dotenv", specifier = "==1.0.1" },
|
||||
@ -1684,8 +1683,7 @@ dev = [
|
||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||
{ name = "sseclient-py", specifier = ">=1.8.0" },
|
||||
{ name = "testcontainers", specifier = "~=4.13.2" },
|
||||
{ name = "ty", specifier = ">=0.0.14" },
|
||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "types-aiofiles", specifier = "~=25.1.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||
{ name = "types-cachetools", specifier = "~=5.5.0" },
|
||||
{ name = "types-cffi", specifier = ">=1.17.0" },
|
||||
@ -4959,11 +4957,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.10.1"
|
||||
version = "2.11.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@ -6278,30 +6276,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/70/26/2591b48412bde75e33bfd292034103ffe41743cacd03120e3242516cd143/transformers-4.56.2-py3-none-any.whl", hash = "sha256:79c03d0e85b26cb573c109ff9eafa96f3c8d4febfd8a0774e8bba32702dd6dde", size = 11608055, upload-time = "2025-09-19T15:16:23.736Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ty"
|
||||
version = "0.0.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typer"
|
||||
version = "0.20.0"
|
||||
@ -6319,11 +6293,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-aiofiles"
|
||||
version = "24.1.0.20250822"
|
||||
version = "25.1.0.20251011"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/48/c64471adac9206cc844afb33ed311ac5a65d2f59df3d861e0f2d0cad7414/types_aiofiles-24.1.0.20250822.tar.gz", hash = "sha256:9ab90d8e0c307fe97a7cf09338301e3f01a163e39f3b529ace82466355c84a7b", size = 14484, upload-time = "2025-08-22T03:02:23.039Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/84/6c/6d23908a8217e36704aa9c79d99a620f2fdd388b66a4b7f72fbc6b6ff6c6/types_aiofiles-25.1.0.20251011.tar.gz", hash = "sha256:1c2b8ab260cb3cd40c15f9d10efdc05a6e1e6b02899304d80dfa0410e028d3ff", size = 14535, upload-time = "2025-10-11T02:44:51.237Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/8e/5e6d2215e1d8f7c2a94c6e9d0059ae8109ce0f5681956d11bb0a228cef04/types_aiofiles-24.1.0.20250822-py3-none-any.whl", hash = "sha256:0ec8f8909e1a85a5a79aed0573af7901f53120dd2a29771dd0b3ef48e12328b0", size = 14322, upload-time = "2025-08-22T03:02:21.918Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/71/0f/76917bab27e270bb6c32addd5968d69e558e5b6f7fb4ac4cbfa282996a96/types_aiofiles-25.1.0.20251011-py3-none-any.whl", hash = "sha256:8ff8de7f9d42739d8f0dadcceeb781ce27cd8d8c4152d4a7c52f6b20edb8149c", size = 14338, upload-time = "2025-10-11T02:44:50.054Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@ -149,7 +149,6 @@ services:
|
||||
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
|
||||
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
|
||||
PM2_INSTANCES: ${PM2_INSTANCES:-2}
|
||||
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
|
||||
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
|
||||
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
|
||||
|
||||
@ -844,7 +844,6 @@ services:
|
||||
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
|
||||
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
|
||||
PM2_INSTANCES: ${PM2_INSTANCES:-2}
|
||||
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
|
||||
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
|
||||
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
|
||||
|
||||
@ -50,24 +50,18 @@ ENV MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
ENV MARKETPLACE_URL=https://marketplace.dify.ai
|
||||
ENV PORT=3000
|
||||
ENV NEXT_TELEMETRY_DISABLED=1
|
||||
ENV PM2_INSTANCES=2
|
||||
|
||||
# set timezone
|
||||
ENV TZ=UTC
|
||||
RUN ln -s /usr/share/zoneinfo/${TZ} /etc/localtime \
|
||||
&& echo ${TZ} > /etc/timezone
|
||||
|
||||
# global runtime packages
|
||||
RUN pnpm add -g pm2
|
||||
|
||||
|
||||
# Create non-root user
|
||||
ARG dify_uid=1001
|
||||
RUN addgroup -S -g ${dify_uid} dify && \
|
||||
adduser -S -u ${dify_uid} -G dify -s /bin/ash -h /home/dify dify && \
|
||||
mkdir /app && \
|
||||
mkdir /.pm2 && \
|
||||
chown -R dify:dify /app /.pm2
|
||||
chown -R dify:dify /app
|
||||
|
||||
|
||||
WORKDIR /app/web
|
||||
|
||||
@ -89,8 +89,6 @@ If you want to customize the host and port:
|
||||
pnpm run start --port=3001 --host=0.0.0.0
|
||||
```
|
||||
|
||||
If you want to customize the number of instances launched by PM2, you can configure `PM2_INSTANCES` in `docker-compose.yaml` or `Dockerfile`.
|
||||
|
||||
## Storybook
|
||||
|
||||
This project uses [Storybook](https://storybook.js.org/) for UI component development.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Alert from './alert'
|
||||
import Alert from '../alert'
|
||||
|
||||
describe('Alert', () => {
|
||||
const defaultProps = {
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import AppUnavailable from './app-unavailable'
|
||||
import AppUnavailable from '../app-unavailable'
|
||||
|
||||
describe('AppUnavailable', () => {
|
||||
beforeEach(() => {
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import Badge from './badge'
|
||||
import Badge from '../badge'
|
||||
|
||||
describe('Badge', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import ThemeSelector from './theme-selector'
|
||||
import ThemeSelector from '../theme-selector'
|
||||
|
||||
// Mock next-themes with controllable state
|
||||
let mockTheme = 'system'
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import ThemeSwitcher from './theme-switcher'
|
||||
import ThemeSwitcher from '../theme-switcher'
|
||||
|
||||
let mockTheme = 'system'
|
||||
const mockSetTheme = vi.fn()
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { ActionButton, ActionButtonState } from './index'
|
||||
import { ActionButton, ActionButtonState } from '../index'
|
||||
|
||||
describe('ActionButton', () => {
|
||||
it('renders button with default props', () => {
|
||||
@ -4,7 +4,7 @@ import type { AgentLogDetailResponse } from '@/models/log'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchAgentLogDetail } from '@/service/log'
|
||||
import AgentLogDetail from './detail'
|
||||
import AgentLogDetail from '../detail'
|
||||
|
||||
vi.mock('@/service/log', () => ({
|
||||
fetchAgentLogDetail: vi.fn(),
|
||||
@ -3,7 +3,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useClickAway } from 'ahooks'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchAgentLogDetail } from '@/service/log'
|
||||
import AgentLogModal from './index'
|
||||
import AgentLogModal from '../index'
|
||||
|
||||
vi.mock('@/service/log', () => ({
|
||||
fetchAgentLogDetail: vi.fn(),
|
||||
@ -1,6 +1,6 @@
|
||||
import type { AgentIteration } from '@/models/log'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import Iteration from './iteration'
|
||||
import Iteration from '../iteration'
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
|
||||
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (
|
||||
@ -1,6 +1,6 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import ResultPanel from './result'
|
||||
import ResultPanel from '../result'
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
|
||||
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (
|
||||
@ -2,7 +2,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import ToolCallItem from './tool-call'
|
||||
import ToolCallItem from '../tool-call'
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
|
||||
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (
|
||||
@ -1,7 +1,7 @@
|
||||
import type { AgentIteration } from '@/models/log'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import TracingPanel from './tracing'
|
||||
import TracingPanel from '../tracing'
|
||||
|
||||
vi.mock('@/app/components/workflow/block-icon', () => ({
|
||||
default: () => <div data-testid="block-icon" />,
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import AnswerIcon from '.'
|
||||
import AnswerIcon from '..'
|
||||
|
||||
describe('AnswerIcon', () => {
|
||||
it('renders default emoji when no icon or image is provided', () => {
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import ImageInput from './ImageInput'
|
||||
import ImageInput from '../ImageInput'
|
||||
|
||||
const createObjectURLMock = vi.fn(() => 'blob:mock-url')
|
||||
const revokeObjectURLMock = vi.fn()
|
||||
@ -1,5 +1,5 @@
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useDraggableUploader } from './hooks'
|
||||
import { useDraggableUploader } from '../hooks'
|
||||
|
||||
type MockDragEventOverrides = {
|
||||
dataTransfer?: { files: File[] }
|
||||
@ -3,7 +3,7 @@ import type { ImageFile } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import AppIconPicker from './index'
|
||||
import AppIconPicker from '../index'
|
||||
import 'vitest-canvas-mock'
|
||||
|
||||
type LocalFileUploaderOptions = {
|
||||
@ -93,7 +93,7 @@ vi.mock('react-easy-crop', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../image-uploader/hooks', () => ({
|
||||
vi.mock('../../image-uploader/hooks', () => ({
|
||||
useLocalFileUploader: (options: LocalFileUploaderOptions) => {
|
||||
mocks.onUpload = options.onUpload
|
||||
return { handleLocalFileUpload: mocks.handleLocalFileUpload }
|
||||
@ -1,4 +1,4 @@
|
||||
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from './utils'
|
||||
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from '../utils'
|
||||
|
||||
type ImageLoadEventType = 'load' | 'error'
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import AppIcon from './index'
|
||||
import AppIcon from '../index'
|
||||
|
||||
// Mock emoji-mart initialization
|
||||
vi.mock('emoji-mart', () => ({
|
||||
@ -2,7 +2,7 @@ import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import i18next from 'i18next'
|
||||
import { useParams, usePathname } from 'next/navigation'
|
||||
import AudioBtn from './index'
|
||||
import AudioBtn from '../index'
|
||||
|
||||
const mockPlayAudio = vi.fn()
|
||||
const mockPauseAudio = vi.fn()
|
||||
@ -4,7 +4,7 @@ import { vi } from 'vitest'
|
||||
import useThemeMock from '@/hooks/use-theme'
|
||||
|
||||
import { Theme } from '@/types/app'
|
||||
import AudioPlayer from './AudioPlayer'
|
||||
import AudioPlayer from '../AudioPlayer'
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: vi.fn(() => ({ theme: 'light' })),
|
||||
@ -3,12 +3,12 @@ import * as React from 'react'
|
||||
// AudioGallery.spec.tsx
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import AudioGallery from './index'
|
||||
import AudioGallery from '../index'
|
||||
|
||||
// Mock AudioPlayer so we only assert prop forwarding
|
||||
const audioPlayerMock = vi.fn()
|
||||
|
||||
vi.mock('./AudioPlayer', () => ({
|
||||
vi.mock('../AudioPlayer', () => ({
|
||||
default: (props: { srcs: string[] }) => {
|
||||
audioPlayerMock(props)
|
||||
return <div data-testid="audio-player" />
|
||||
@ -1,6 +1,6 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { sleep } from '@/utils'
|
||||
import AutoHeightTextarea from './index'
|
||||
import AutoHeightTextarea from '../index'
|
||||
|
||||
vi.mock('@/utils', async () => {
|
||||
const actual = await vi.importActual('@/utils')
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Avatar from './index'
|
||||
import Avatar from '../index'
|
||||
|
||||
describe('Avatar', () => {
|
||||
beforeEach(() => {
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Badge, { BadgeState, BadgeVariants } from './index'
|
||||
import Badge, { BadgeState, BadgeVariants } from '../index'
|
||||
|
||||
describe('Badge', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -1,7 +1,7 @@
|
||||
import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import BlockInput, { getInputKeys } from './index'
|
||||
import BlockInput, { getInputKeys } from '../index'
|
||||
|
||||
vi.mock('@/utils/var', () => ({
|
||||
checkKeys: vi.fn((_keys: string[]) => ({
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import AddButton from './add-button'
|
||||
import AddButton from '../add-button'
|
||||
|
||||
describe('AddButton', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -1,6 +1,6 @@
|
||||
import { cleanup, fireEvent, render } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import Button from './index'
|
||||
import Button from '../index'
|
||||
|
||||
afterEach(cleanup)
|
||||
// https://testing-library.com/docs/queries/about
|
||||
@ -1,5 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import SyncButton from './sync-button'
|
||||
import SyncButton from '../sync-button'
|
||||
|
||||
describe('SyncButton', () => {
|
||||
describe('Rendering', () => {
|
||||
@ -1,7 +1,7 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import useEmblaCarousel from 'embla-carousel-react'
|
||||
import { Carousel, useCarousel } from './index'
|
||||
import { Carousel, useCarousel } from '../index'
|
||||
|
||||
vi.mock('embla-carousel-react', () => ({
|
||||
default: vi.fn(),
|
||||
@ -1,5 +1,5 @@
|
||||
import type { ChatConfig, ChatItemInTree } from '../types'
|
||||
import type { ChatWithHistoryContextValue } from './context'
|
||||
import type { ChatConfig, ChatItemInTree } from '../../types'
|
||||
import type { ChatWithHistoryContextValue } from '../context'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
|
||||
import type { HumanInputFormData } from '@/types/workflow'
|
||||
@ -12,17 +12,17 @@ import {
|
||||
stopChatMessageResponding,
|
||||
} from '@/service/share'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { useChat } from '../chat/hooks'
|
||||
import { useChat } from '../../chat/hooks'
|
||||
|
||||
import { isValidGeneratedAnswer } from '../utils'
|
||||
import ChatWrapper from './chat-wrapper'
|
||||
import { useChatWithHistoryContext } from './context'
|
||||
import { isValidGeneratedAnswer } from '../../utils'
|
||||
import ChatWrapper from '../chat-wrapper'
|
||||
import { useChatWithHistoryContext } from '../context'
|
||||
|
||||
vi.mock('../chat/hooks', () => ({
|
||||
vi.mock('../../chat/hooks', () => ({
|
||||
useChat: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('./context', () => ({
|
||||
vi.mock('../context', () => ({
|
||||
useChatWithHistoryContext: vi.fn(),
|
||||
}))
|
||||
|
||||
@ -37,7 +37,7 @@ vi.mock('next/navigation', () => ({
|
||||
useParams: vi.fn(() => ({ token: 'test-token' })),
|
||||
}))
|
||||
|
||||
vi.mock('../utils', () => ({
|
||||
vi.mock('../../utils', () => ({
|
||||
isValidGeneratedAnswer: vi.fn(),
|
||||
getLastAnswer: vi.fn(),
|
||||
}))
|
||||
@ -1,12 +1,12 @@
|
||||
import type { ChatConfig } from '../types'
|
||||
import type { ChatWithHistoryContextValue } from './context'
|
||||
import type { ChatConfig } from '../../types'
|
||||
import type { ChatWithHistoryContextValue } from '../context'
|
||||
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { useChatWithHistoryContext } from './context'
|
||||
import HeaderInMobile from './header-in-mobile'
|
||||
import { useChatWithHistoryContext } from '../context'
|
||||
import HeaderInMobile from '../header-in-mobile'
|
||||
|
||||
vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
default: vi.fn(),
|
||||
@ -17,7 +17,7 @@ vi.mock('@/hooks/use-breakpoints', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('./context', () => ({
|
||||
vi.mock('../context', () => ({
|
||||
useChatWithHistoryContext: vi.fn(),
|
||||
ChatWithHistoryContext: { Provider: ({ children }: { children: React.ReactNode }) => <div>{children}</div> },
|
||||
}))
|
||||
@ -33,7 +33,7 @@ vi.mock('next/navigation', () => ({
|
||||
useParams: vi.fn(() => ({})),
|
||||
}))
|
||||
|
||||
vi.mock('../embedded-chatbot/theme/theme-context', () => ({
|
||||
vi.mock('../../embedded-chatbot/theme/theme-context', () => ({
|
||||
useThemeContext: vi.fn(() => ({
|
||||
buildTheme: vi.fn(),
|
||||
})),
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user