diff --git a/api/.env.example b/api/.env.example
index ae7e82c779..7878308588 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
+
+# Dataset queue monitor configuration
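+# An alert is triggered when the number of pending tasks in the dataset queue reaches this threshold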
+QUEUE_MONITOR_THRESHOLD=200
+# Multiple recipients can be configured, separated by commas, e.g. test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitoring interval in minutes (default: 30)
+QUEUE_MONITOR_INTERVAL=30
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 1b015b3267..2dcf1710b0 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -2,7 +2,7 @@ import os
from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus
-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
+from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
@@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
)
+class DatasetQueueMonitorConfig(BaseSettings):
+ """
+ Configuration settings for Dataset Queue Monitor
+ """
+
+ QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
+ description="Threshold for dataset queue monitor",
+ default=200,
+ )
+ QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
+ description="Emails for dataset queue monitor alert, separated by commas",
+ default=None,
+ )
+ QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
+ description="Interval for dataset queue monitor in minutes",
+ default=30,
+ )
+
+
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@@ -303,5 +322,6 @@ class MiddlewareConfig(
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
+ DatasetQueueMonitorConfig,
):
pass
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index ea8a9f0f41..ab7ab4dcf0 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
- if not dataset.indexing_technique and not args.get("indexing_technique"):
+
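+        # Fall back to the dataset's configured indexing technique when the request does not specify one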
+ indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
+ if not indexing_technique:
raise ValueError("indexing_technique is required.")
+ args["indexing_technique"] = indexing_technique
# save file info
file = request.files["file"]
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 5017835565..e1c021a44a 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
+ def raise_for_status(self) -> None:
+ """
+ Check model status and raise ValueError if not active.
+
+ :raises ValueError: When model status is not active, with a descriptive message
+ """
+ if self.status == ModelStatus.ACTIVE:
+ return
+
+ error_messages = {
+ ModelStatus.NO_CONFIGURE: "Model is not configured",
+ ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
+ ModelStatus.NO_PERMISSION: "No permission to use this model",
+ ModelStatus.DISABLED: "Model is disabled",
+ }
+
+ if self.status in error_messages:
+ raise ValueError(error_messages[self.status])
+
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 231743bf2a..06fdb089d4 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -41,45 +41,53 @@ class Extensible:
extensions = []
position_map: dict[str, int] = {}
- # get the path of the current class
- current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
- current_dir_path = os.path.dirname(current_path)
+ # Get the package name from the module path
+ package_name = ".".join(cls.__module__.split(".")[:-1])
- # traverse subdirectories
- for subdir_name in os.listdir(current_dir_path):
- if subdir_name.startswith("__"):
- continue
+ try:
+ # Get package directory path
+ package_spec = importlib.util.find_spec(package_name)
+ if not package_spec or not package_spec.origin:
+ raise ImportError(f"Could not find package {package_name}")
- subdir_path = os.path.join(current_dir_path, subdir_name)
- extension_name = subdir_name
- if os.path.isdir(subdir_path):
+ package_dir = os.path.dirname(package_spec.origin)
+
+ # Traverse subdirectories
+ for subdir_name in os.listdir(package_dir):
+ if subdir_name.startswith("__"):
+ continue
+
+ subdir_path = os.path.join(package_dir, subdir_name)
+ if not os.path.isdir(subdir_path):
+ continue
+
+ extension_name = subdir_name
file_names = os.listdir(subdir_path)
- # is builtin extension, builtin extension
- # in the front-end page and business logic, there are special treatments.
+ # Check for extension module file
+ if (extension_name + ".py") not in file_names:
+ logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+ continue
+
+ # Check for builtin flag and position
builtin = False
- # default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
-
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position
- if (extension_name + ".py") not in file_names:
- logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
- continue
-
- # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
- py_path = os.path.join(subdir_path, extension_name + ".py")
- spec = importlib.util.spec_from_file_location(extension_name, py_path)
+ # Import the extension module
+ module_name = f"{package_name}.{extension_name}.{extension_name}"
+ spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
- raise Exception(f"Failed to load module {extension_name} from {py_path}")
+ raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
+ # Find extension class
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -87,21 +95,21 @@ class Extensible:
break
if not extension_class:
- logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
+ logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue
+ # Load schema if not builtin
json_data: dict[str, Any] = {}
if not builtin:
- if "schema.json" not in file_names:
+ json_path = os.path.join(subdir_path, "schema.json")
+ if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
- json_path = os.path.join(subdir_path, "schema.json")
- json_data = {}
- if os.path.exists(json_path):
- with open(json_path, encoding="utf-8") as f:
- json_data = json.load(f)
+ with open(json_path, encoding="utf-8") as f:
+ json_data = json.load(f)
+ # Create extension
extensions.append(
ModuleExtension(
extension_class=extension_class,
@@ -113,6 +121,11 @@ class Extensible:
)
)
+    except Exception:
+ logging.exception("Error scanning extensions")
+ raise
+
+ # Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 373ef2bbe2..568149cc37 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
+ @property
+ def support_structure_output(self) -> bool:
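+        """Whether the model's declared features include structured output."""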
+ return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
+
class ParameterRule(BaseModel):
"""
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 7570200175..488a394679 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
+from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
- """
- Get all provider records of the workspace.
-
- :param tenant_id: workspace id
- :return:
- """
- providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
-
provider_name_to_provider_records_dict = defaultdict(list)
- for provider in providers:
- # TODO: Use provider name with prefix after the data migration
- provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
-
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
+ providers = session.scalars(stmt)
+ for provider in providers:
+ # Use provider name with prefix after the data migration
+ provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict
@staticmethod
@@ -416,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- # Get all provider model records of the workspace
- provider_models = (
- db.session.query(ProviderModel)
- .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
- .all()
- )
-
provider_name_to_provider_model_records_dict = defaultdict(list)
- for provider_model in provider_models:
- provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
-
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
+ provider_models = session.scalars(stmt)
+ for provider_model in provider_models:
+ provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict
@staticmethod
@@ -437,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- preferred_provider_types = (
- db.session.query(TenantPreferredModelProvider)
- .filter(TenantPreferredModelProvider.tenant_id == tenant_id)
- .all()
- )
-
- provider_name_to_preferred_provider_type_records_dict = {
- preferred_provider_type.provider_name: preferred_provider_type
- for preferred_provider_type in preferred_provider_types
- }
-
+ provider_name_to_preferred_provider_type_records_dict = {}
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
+ preferred_provider_types = session.scalars(stmt)
+ provider_name_to_preferred_provider_type_records_dict = {
+ preferred_provider_type.provider_name: preferred_provider_type
+ for preferred_provider_type in preferred_provider_types
+ }
return provider_name_to_preferred_provider_type_records_dict
@staticmethod
@@ -458,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- provider_model_settings = (
- db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
- )
-
provider_name_to_provider_model_settings_dict = defaultdict(list)
- for provider_model_setting in provider_model_settings:
- (
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
+ provider_model_settings = session.scalars(stmt)
+ for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting
)
- )
-
return provider_name_to_provider_model_settings_dict
@staticmethod
@@ -492,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled:
return {}
- provider_load_balancing_configs = (
- db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
- )
-
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
- for provider_load_balancing_config in provider_load_balancing_configs:
- provider_name_to_provider_load_balancing_model_configs_dict[
- provider_load_balancing_config.provider_name
- ].append(provider_load_balancing_config)
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
+ provider_load_balancing_configs = session.scalars(stmt)
+ for provider_load_balancing_config in provider_load_balancing_configs:
+ provider_name_to_provider_load_balancing_model_configs_dict[
+ provider_load_balancing_config.provider_name
+ ].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict
@@ -626,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials:
try:
# fix origin data
- if (
- custom_provider_record.encrypted_config
- and not custom_provider_record.encrypted_config.startswith("{")
- ):
+ if custom_provider_record.encrypted_config is None:
+ raise ValueError("No credentials found")
+ if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -733,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False)
# Convert provider_records to dict
- quota_type_to_provider_records_dict = {}
+ quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
continue
@@ -758,6 +740,11 @@ class ProviderManager:
else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
+ if provider_record.quota_used is None:
+ raise ValueError("quota_used is None")
+ if provider_record.quota_limit is None:
+ raise ValueError("quota_limit is None")
+
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@@ -791,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials:
- try:
- provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
- except JSONDecodeError:
- provider_credentials = {}
+            provider_credentials: dict[str, Any] = {}
+            if provider_records and provider_records[0].encrypted_config:
+                try:
+                    provider_credentials = json.loads(provider_records[0].encrypted_config)
+                except JSONDecodeError:
+                    # fall back to empty credentials on malformed JSON
+                    pass
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 486b4b01af..36d0688807 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
- structured_output_enabled: bool = False
+    # 'structured_output_enabled' is the legacy field name; it is kept as an alias for backward compatibility.
+ structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before")
@classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None:
return PromptConfig()
return v
+
+ @property
+ def structured_output_enabled(self) -> bool:
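+        """True only when the switch is on and a structured output schema is configured."""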
+ return self.structured_output_switch_on and self.structured_output is not None
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index df8f614db3..ee181cf3bf 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
- SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage,
)
)
- except LLMNodeError as e:
+ except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
- model_name = node_data_model.name
- provider_name = node_data_model.provider
+ if not node_data_model.mode:
+ raise LLMModeRequiredError("LLM mode is required.")
- model_manager = ModelManager()
- model_instance = model_manager.get_model_instance(
- tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+ model = ModelManager().get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=node_data_model.provider,
+ model=node_data_model.name,
)
- provider_model_bundle = model_instance.provider_model_bundle
- model_type_instance = model_instance.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
- model_credentials = model_instance.credentials
+ model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
- provider_model = provider_model_bundle.configuration.get_provider_model(
- model=model_name, model_type=ModelType.LLM
+ provider_model = model.provider_model_bundle.configuration.get_provider_model(
+ model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
- raise ModelNotExistError(f"Model {model_name} not exist.")
-
- if provider_model.status == ModelStatus.NO_CONFIGURE:
- raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
- elif provider_model.status == ModelStatus.NO_PERMISSION:
- raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
- elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
- raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+ raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+ provider_model.raise_for_status()
# model config
- completion_params = node_data_model.completion_params
- stop = []
- if "stop" in completion_params:
- stop = completion_params["stop"]
- del completion_params["stop"]
-
- # get model mode
- model_mode = node_data_model.mode
- if not model_mode:
- raise LLMModeRequiredError("LLM mode is required.")
-
- model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+ stop: list[str] = []
+ if "stop" in node_data_model.completion_params:
+ stop = node_data_model.completion_params.pop("stop")
+ model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
- raise ModelNotExistError(f"Model {model_name} not exist.")
- support_structured_output = self._check_model_structured_output_support()
- if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
- completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
- elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
- # Set appropriate response format based on model capabilities
- self._set_response_format(completion_params, model_schema.parameter_rules)
- return model_instance, ModelConfigWithCredentialsEntity(
- provider=provider_name,
- model=model_name,
+ raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+
+ if self.node_data.structured_output_enabled:
+ if model_schema.support_structure_output:
+ node_data_model.completion_params = self._handle_native_json_schema(
+ node_data_model.completion_params, model_schema.parameter_rules
+ )
+ else:
+ # Set appropriate response format based on model capabilities
+ self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
+
+ return model, ModelConfigWithCredentialsEntity(
+ provider=node_data_model.provider,
+ model=node_data_model.name,
model_schema=model_schema,
- mode=model_mode,
- provider_model_bundle=provider_model_bundle,
- credentials=model_credentials,
- parameters=completion_params,
+ mode=node_data_model.mode,
+ provider_model_bundle=model.provider_model_bundle,
+ credentials=model.credentials,
+ parameters=node_data_model.completion_params,
stop=stop,
)
@@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
- support_structured_output = self._check_model_structured_output_support()
- if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
- filtered_prompt_messages = self._handle_prompt_based_schema(
- prompt_messages=filtered_prompt_messages,
- )
- stop = model_config.stop
- return filtered_prompt_messages, stop
+
+ model = ModelManager().get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=self.node_data.model.provider,
+ model=self.node_data.model.name,
+ )
+ model_schema = model.model_type_instance.get_model_schema(
+ model=self.node_data.model.name,
+ credentials=model.credentials,
+ )
+ if not model_schema:
+ raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
+ if self.node_data.structured_output_enabled:
+ if not model_schema.support_structure_output:
+ filtered_prompt_messages = self._handle_prompt_based_schema(
+ prompt_messages=filtered_prompt_messages,
+ )
+ return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
@@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
- variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
+ variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
@@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
- def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
- """
- Check if the current model supports structured output.
-
- Returns:
- SupportStructuredOutput: The support status of structured output
- """
- # Early return if structured output is disabled
- if (
- not isinstance(self.node_data, LLMNodeData)
- or not self.node_data.structured_output_enabled
- or not self.node_data.structured_output
- ):
- return SupportStructuredOutputStatus.DISABLED
- # Get model schema and check if it exists
- model_schema = self._fetch_model_schema(self.node_data.model.provider)
- if not model_schema:
- return SupportStructuredOutputStatus.DISABLED
-
- # Check if model supports structured output feature
- return (
- SupportStructuredOutputStatus.SUPPORTED
- if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
- else SupportStructuredOutputStatus.UNSUPPORTED
- )
-
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,
diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py
index 7954acbaee..6491042bfe 100644
--- a/api/core/workflow/utils/structured_output/entities.py
+++ b/api/core/workflow/utils/structured_output/entities.py
@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini"
OLLAMA = "ollama"
-
-
-class SupportStructuredOutputStatus(StrEnum):
- """Constants for structured output support status"""
-
- SUPPORTED = "supported"
- UNSUPPORTED = "unsupported"
- DISABLED = "disabled"
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 26bd6b3577..a837552007 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
"schedule.update_tidb_serverless_status_task",
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
+ "schedule.queue_monitor_task",
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
@@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
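+        # Periodically check the dataset queue backlog; the interval is configurable via QUEUE_MONITOR_INTERVAL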
+ "datasets-queue-monitor": {
+ "task": "schedule.queue_monitor_task.queue_monitor_task",
+ "schedule": timedelta(
+ minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
+ ),
+ },
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
diff --git a/api/models/provider.py b/api/models/provider.py
index 497cbefc61..1e25f0c90f 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -1,6 +1,9 @@
+from datetime import datetime
from enum import Enum
+from typing import Optional
-from sqlalchemy import func
+from sqlalchemy import func, text
+from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
@@ -51,20 +54,24 @@ class Provider(Base):
),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
- encrypted_config = db.Column(db.Text, nullable=True)
- is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- last_used = db.Column(db.DateTime, nullable=True)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ provider_type: Mapped[str] = mapped_column(
+ db.String(40), nullable=False, server_default=text("'custom'::character varying")
+ )
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
- quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
- quota_limit = db.Column(db.BigInteger, nullable=True)
- quota_used = db.Column(db.BigInteger, default=0)
+ quota_type: Mapped[Optional[str]] = mapped_column(
+ db.String(40), nullable=True, server_default=text("''::character varying")
+ )
+ quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
+ quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self):
return (
@@ -104,15 +111,15 @@ class ProviderModel(Base):
),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- encrypted_config = db.Column(db.Text, nullable=True)
- is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base):
@@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
@@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- preferred_provider_type = db.Column(db.String(40), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
@@ -153,22 +160,24 @@ class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- payment_product_id = db.Column(db.String(191), nullable=False)
- payment_id = db.Column(db.String(191))
- transaction_id = db.Column(db.String(191))
- quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
- currency = db.Column(db.String(40))
- total_amount = db.Column(db.Integer)
- payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
- paid_at = db.Column(db.DateTime)
- pay_failed_at = db.Column(db.DateTime)
- refunded_at = db.Column(db.DateTime)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
+ payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+ transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+ quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
+ currency: Mapped[Optional[str]] = mapped_column(db.String(40))
+ total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
+ payment_status: Mapped[str] = mapped_column(
+ db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
+ )
+ paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
@@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+ load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
@@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- name = db.Column(db.String(255), nullable=False)
- encrypted_config = db.Column(db.Text, nullable=True)
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
new file mode 100644
index 0000000000..e3a7021b9d
--- /dev/null
+++ b/api/schedule/queue_monitor_task.py
@@ -0,0 +1,62 @@
+import logging
+from datetime import datetime
+from urllib.parse import urlparse
+
+import click
+from flask import render_template
+from redis import Redis
+
+import app
+from configs import dify_config
+from extensions.ext_database import db
+from extensions.ext_mail import mail
+
+# Create a dedicated Redis connection (using the same configuration as Celery)
+celery_broker_url = dify_config.CELERY_BROKER_URL
+
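+# Broker URL format: redis://[:password@]host:port/db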
+parsed = urlparse(celery_broker_url)
+host = parsed.hostname or "localhost"
+port = parsed.port or 6379
+password = parsed.password or None
+redis_db = int(parsed.path.strip("/") or "1")
+
+celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
+
+
+@app.celery.task(queue="monitor")
+def queue_monitor_task():
+ queue_name = "dataset"
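+    # With the default Redis broker settings, Celery keeps each queue's pending tasks in a list named after the queue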
+    threshold = dify_config.QUEUE_MONITOR_THRESHOLD
+    if not threshold:
+        # Nothing to monitor when no threshold is configured
+        return
+
+ try:
+        logging.info(click.style(f"Start monitoring queue: {queue_name}", fg="green"))
+        queue_length = celery_redis.llen(queue_name)
+        logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
+
+ if queue_length >= threshold:
+            warning_msg = f"Queue {queue_name} task count exceeded the limit: {queue_length}/{threshold}"
+ logging.warning(click.style(warning_msg, fg="red"))
+            alert_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
+            if alert_emails:
+                to_list = alert_emails.split(",")
+ for to in to_list:
+ try:
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ html_content = render_template(
+ "queue_monitor_alert_email_template_en-US.html",
+ queue_name=queue_name,
+ queue_length=queue_length,
+ threshold=threshold,
+ alert_time=current_time,
+ )
+ mail.send(
+ to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
+ )
+                    except Exception:
+ logging.exception(click.style("Exception occurred during sending email", fg="red"))
+
+    except Exception:
+ logging.exception(click.style("Exception occurred during queue monitoring", fg="red"))
+ finally:
+ if db.session.is_active:
+ db.session.close()
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index f32bc4f187..51b6343fdc 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -5,7 +5,7 @@ import uuid
import click
from celery import shared_task # type: ignore
-from sqlalchemy import func, select
+from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
@@ -68,11 +68,6 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
- word_count_change = 0
- segments_to_insert: list[str] = []
- max_position_stmt = select(func.max(DocumentSegment.position)).where(
- DocumentSegment.document_id == dataset_document.id
- )
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
diff --git a/api/templates/queue_monitor_alert_email_template_en-US.html b/api/templates/queue_monitor_alert_email_template_en-US.html
new file mode 100644
index 0000000000..2885210864
--- /dev/null
+++ b/api/templates/queue_monitor_alert_email_template_en-US.html
@@ -0,0 +1,129 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <title>Queue Monitoring Alert</title>
+</head>
+<body>
+    <h1>Queue Monitoring Alert</h1>
+    <p>Our system has detected an abnormal queue status that requires your attention:</p>
+    <h2>Queue Task Alert</h2>
+    <p>Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})</p>
+    <p>Recommended actions:</p>
+    <p>1. Check the queue processing status in the system dashboard</p>
+    <p>2. Verify if there are any processing bottlenecks</p>
+    <p>3. Consider scaling up workers if needed</p>
+    <p>Additional Information:</p>
+    <p>- Alert triggered at: {{alert_time}}</p>
+</body>
+</html>
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 5fbee266bd..6aa48b1cbb 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -3,11 +3,16 @@ import os
import time
import uuid
from collections.abc import Generator
-from unittest.mock import MagicMock
+from decimal import Decimal
+from unittest.mock import MagicMock, patch
import pytest
+from app_factory import create_app
+from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
@@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
-from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+@pytest.fixture(scope="session")
+def app():
+ # Set up storage configuration
+ os.environ["STORAGE_TYPE"] = "opendal"
+ os.environ["OPENDAL_SCHEME"] = "fs"
+ os.environ["OPENDAL_FS_ROOT"] = "storage"
+
+ # Ensure storage directory exists
+ os.makedirs("storage", exist_ok=True)
+
+ app = create_app()
+ dify_config.LOGIN_DISABLED = True
+ return app
+
+
def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
@@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
graph = Graph.init(graph_config=graph_config)
+ # Use proper UUIDs for database compatibility
+ tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+ app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
+ workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
+ user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
+
init_params = GraphInitParams(
- tenant_id="1",
- app_id="1",
+ tenant_id=tenant_id,
+ app_id=app_id,
workflow_type=WorkflowType.WORKFLOW,
- workflow_id="1",
+ workflow_id=workflow_id,
graph_config=graph_config,
- user_id="1",
+ user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
@@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
return node
-def test_execute_llm(setup_model_mock):
- node = init_llm_node(
- config={
- "id": "llm",
- "data": {
- "title": "123",
- "type": "llm",
- "model": {
- "provider": "langgenius/openai/openai",
- "name": "gpt-3.5-turbo",
- "mode": "chat",
- "completion_params": {},
+def test_execute_llm(app):
+ with app.app_context():
+ node = init_llm_node(
+ config={
+ "id": "llm",
+ "data": {
+ "title": "123",
+ "type": "llm",
+ "model": {
+ "provider": "langgenius/openai/openai",
+ "name": "gpt-3.5-turbo",
+ "mode": "chat",
+ "completion_params": {},
+ },
+ "prompt_template": [
+ {
+ "role": "system",
+ "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
+ },
+ {"role": "user", "text": "{{#sys.query#}}"},
+ ],
+ "memory": None,
+ "context": {"enabled": False},
+ "vision": {"enabled": False},
},
- "prompt_template": [
- {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
- {"role": "user", "text": "{{#sys.query#}}"},
- ],
- "memory": None,
- "context": {"enabled": False},
- "vision": {"enabled": False},
},
- },
- )
+ )
-    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
- # Mock db.session.close()
- db.session.close = MagicMock()
+ # Create a proper LLM result with real entities
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal("1000"),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal("1000"),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
- node._fetch_model_config = get_mocked_fetch_model_config(
- provider="langgenius/openai/openai",
- model="gpt-3.5-turbo",
- mode="chat",
- credentials=credentials,
- )
+ mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
- # execute node
- result = node._run()
- assert isinstance(result, Generator)
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
- for item in result:
- if isinstance(item, RunCompletedEvent):
- assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
- assert item.run_result.process_data is not None
- assert item.run_result.outputs is not None
- assert item.run_result.outputs.get("text") is not None
- assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
+ # Create a simple mock model instance that doesn't call real providers
+ mock_model_instance = MagicMock()
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create a simple mock model config with required attributes
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "langgenius/openai/openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+ # Mock the _fetch_model_config method
+ def mock_fetch_model_config_func(_node_data_model):
+ return mock_model_instance, mock_model_config
+
+ # Also mock ModelManager.get_model_instance to avoid database calls
+ def mock_get_model_instance(_self, **kwargs):
+ return mock_model_instance
+
+ with (
+ patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+ patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ ):
+ # execute node
+ result = node._run()
+ assert isinstance(result, Generator)
+
+ for item in result:
+ if isinstance(item, RunCompletedEvent):
+ assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert item.run_result.process_data is not None
+ assert item.run_result.outputs is not None
+ assert item.run_result.outputs.get("text") is not None
+ assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
+def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
"""
Test execute LLM node with jinja2
"""
- node = init_llm_node(
- config={
- "id": "llm",
- "data": {
- "title": "123",
- "type": "llm",
- "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
- "prompt_config": {
- "jinja2_variables": [
- {"variable": "sys_query", "value_selector": ["sys", "query"]},
- {"variable": "output", "value_selector": ["abc", "output"]},
- ]
+ with app.app_context():
+ node = init_llm_node(
+ config={
+ "id": "llm",
+ "data": {
+ "title": "123",
+ "type": "llm",
+ "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
+ "prompt_config": {
+ "jinja2_variables": [
+ {"variable": "sys_query", "value_selector": ["sys", "query"]},
+ {"variable": "output", "value_selector": ["abc", "output"]},
+ ]
+ },
+ "prompt_template": [
+ {
+ "role": "system",
+ "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
+ "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
+ "edition_type": "jinja2",
+ },
+ {
+ "role": "user",
+ "text": "{{#sys.query#}}",
+ "jinja2_text": "{{sys_query}}",
+ "edition_type": "basic",
+ },
+ ],
+ "memory": None,
+ "context": {"enabled": False},
+ "vision": {"enabled": False},
},
- "prompt_template": [
- {
- "role": "system",
- "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
- "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
- "edition_type": "jinja2",
- },
- {
- "role": "user",
- "text": "{{#sys.query#}}",
- "jinja2_text": "{{sys_query}}",
- "edition_type": "basic",
- },
- ],
- "memory": None,
- "context": {"enabled": False},
- "vision": {"enabled": False},
},
- },
- )
+ )
- credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
+ # Mock db.session.close()
+ db.session.close = MagicMock()
- # Mock db.session.close()
- db.session.close = MagicMock()
+ # Create a proper LLM result with real entities
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal("1000"),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal("1000"),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
- node._fetch_model_config = get_mocked_fetch_model_config(
- provider="langgenius/openai/openai",
- model="gpt-3.5-turbo",
- mode="chat",
- credentials=credentials,
- )
+ mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
- # execute node
- result = node._run()
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
- for item in result:
- if isinstance(item, RunCompletedEvent):
- assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
- assert item.run_result.process_data is not None
- assert "sunny" in json.dumps(item.run_result.process_data)
- assert "what's the weather today?" in json.dumps(item.run_result.process_data)
+ # Create a simple mock model instance that doesn't call real providers
+ mock_model_instance = MagicMock()
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create a simple mock model config with required attributes
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+ # Mock the _fetch_model_config method
+ def mock_fetch_model_config_func(_node_data_model):
+ return mock_model_instance, mock_model_config
+
+ # Also mock ModelManager.get_model_instance to avoid database calls
+ def mock_get_model_instance(_self, **kwargs):
+ return mock_model_instance
+
+ with (
+ patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+ patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ ):
+ # execute node
+ result = node._run()
+
+ for item in result:
+ if isinstance(item, RunCompletedEvent):
+ assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert item.run_result.process_data is not None
+ assert "sunny" in json.dumps(item.run_result.process_data)
+ assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json():
diff --git a/docker/.env.example b/docker/.env.example
index ac9536be03..4cf5e202d0 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600
PIP_MIRROR_URL=
# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example
-# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss
+# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos
PLUGIN_STORAGE_TYPE=local
PLUGIN_STORAGE_LOCAL_ROOT=/app/storage
PLUGIN_WORKING_PATH=/app/storage/cwd
@@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
+# Plugin oss volcengine tos
+PLUGIN_VOLCENGINE_TOS_ENDPOINT=
+PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
+PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
+PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------
# OTLP Collector Configuration
@@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
+
+# Dataset queue monitor configuration
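+# An alert is triggered when the number of pending tasks in the dataset queue reaches this threshold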
+QUEUE_MONITOR_THRESHOLD=200
+# Multiple recipients can be configured, separated by commas, e.g. test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitoring interval in minutes (default: 30)
+QUEUE_MONITOR_INTERVAL=30
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index 74a7b87bf9..75bdab1a06 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -184,6 +184,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index d4a0b94619..8276e2977f 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -121,6 +121,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 41e86d015f..e559021684 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-}
@@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env
OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000}
OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000}
ALLOW_EMBED: ${ALLOW_EMBED:-false}
+ QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
+ QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
+ QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
services:
# API service
@@ -683,6 +690,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:
diff --git a/docker/middleware.env.example b/docker/middleware.env.example
index ba6859885b..66037f281c 100644
--- a/docker/middleware.env.example
+++ b/docker/middleware.env.example
@@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
+# Plugin oss volcengine tos
+PLUGIN_VOLCENGINE_TOS_ENDPOINT=
+PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
+PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
+PLUGIN_VOLCENGINE_TOS_REGION=
diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx
index 975f8aeb6c..048ae2a576 100644
--- a/web/app/components/base/app-icon-picker/index.tsx
+++ b/web/app/components/base/app-icon-picker/index.tsx
@@ -5,7 +5,6 @@ import type { Area } from 'react-easy-crop'
import Modal from '../modal'
import Divider from '../divider'
import Button from '../button'
-import { ImagePlus } from '../icons/src/vender/line/images'
import { useLocalFileUploader } from '../image-uploader/hooks'
import EmojiPickerInner from '../emoji-picker/Inner'
import type { OnImageInput } from './ImageInput'
@@ -16,6 +15,7 @@ import type { AppIconType, ImageFile } from '@/types/app'
import cn from '@/utils/classnames'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
import { noop } from 'lodash-es'
+import { RiImageCircleAiLine } from '@remixicon/react'
export type AppIconEmojiSelection = {
type: 'emoji'
@@ -46,7 +46,7 @@ const AppIconPicker: FC = ({
const tabs = [
{ key: 'emoji', label: t('app.iconPicker.emoji'), icon: <span>🤖</span> },
-    { key: 'image', label: t('app.iconPicker.image'), icon: <ImagePlus /> },
+    { key: 'image', label: t('app.iconPicker.image'), icon: <RiImageCircleAiLine /> },
]
const [activeTab, setActiveTab] = useState('emoji')
@@ -119,10 +119,10 @@ const AppIconPicker: FC = ({
{tabs.map(tab => (