diff --git a/api/.env.example b/api/.env.example index ae7e82c779..7878308588 100644 --- a/api/.env.example +++ b/api/.env.example @@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000 # Prevent Clickjacking ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 1b015b3267..2dcf1710b0 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -2,7 +2,7 @@ import os from typing import Any, Literal, Optional from urllib.parse import parse_qsl, quote_plus -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field +from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from .cache.redis_config import RedisConfig @@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings): ) +class DatasetQueueMonitorConfig(BaseSettings): + """ + Configuration settings for Dataset Queue Monitor + """ + + QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + description="Threshold for dataset queue monitor", + default=200, + ) + QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + description="Emails for dataset queue monitor alert, separated by commas", + default=None, + ) + QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + description="Interval for dataset queue monitor in minutes", + default=30, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -303,5 +322,6 @@ class MiddlewareConfig( BaiduVectorDBConfig, OpenGaussConfig, TableStoreConfig, + DatasetQueueMonitorConfig, ): pass diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ea8a9f0f41..ab7ab4dcf0 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") - if not dataset.indexing_technique and not args.get("indexing_technique"): + + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique + if not indexing_technique: raise ValueError("indexing_technique is required.") + args["indexing_technique"] = indexing_technique # save file info file = request.files["file"] diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 5017835565..e1c021a44a 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + def raise_for_status(self) -> None: + """ + Check model status and raise ValueError if not active. 
+ + :raises ValueError: When model status is not active, with a descriptive message + """ + if self.status == ModelStatus.ACTIVE: + return + + error_messages = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + if self.status in error_messages: + raise ValueError(error_messages[self.status]) + class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 231743bf2a..06fdb089d4 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -41,45 +41,53 @@ class Extensible: extensions = [] position_map: dict[str, int] = {} - # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") - current_dir_path = os.path.dirname(current_path) + # Get the package name from the module path + package_name = ".".join(cls.__module__.split(".")[:-1]) - # traverse subdirectories - for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith("__"): - continue + try: + # Get package directory path + package_spec = importlib.util.find_spec(package_name) + if not package_spec or not package_spec.origin: + raise ImportError(f"Could not find package {package_name}") - subdir_path = os.path.join(current_dir_path, subdir_name) - extension_name = subdir_name - if os.path.isdir(subdir_path): + package_dir = os.path.dirname(package_spec.origin) + + # Traverse subdirectories + for subdir_name in os.listdir(package_dir): + if subdir_name.startswith("__"): + continue + + subdir_path = os.path.join(package_dir, subdir_name) + if not os.path.isdir(subdir_path): + continue + + extension_name = subdir_name file_names = os.listdir(subdir_path) - # is builtin extension, builtin extension - # in the front-end page and business logic, there are special treatments. 
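+ # Note: modules are now resolved with importlib.util.find_spec on the dotted
+ # path f"{package_name}.{extension_name}.{extension_name}" (see below) instead
+ # of a filesystem path built from cls.__module__, so extension discovery no
+ # longer depends on the process's current working directory.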
+ # Check for extension module file + if (extension_name + ".py") not in file_names: + logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + continue + + # Check for builtin flag and position builtin = False - # default position is 0 can not be None for sort_to_dict_by_position_map position = 0 if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position_map[extension_name] = position - if (extension_name + ".py") not in file_names: - logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") - continue - - # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + ".py") - spec = importlib.util.spec_from_file_location(extension_name, py_path) + # Import the extension module + module_name = f"{package_name}.{extension_name}.{extension_name}" + spec = importlib.util.find_spec(module_name) if not spec or not spec.loader: - raise Exception(f"Failed to load module {extension_name} from {py_path}") + raise ImportError(f"Failed to load module {module_name}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) + # Find extension class extension_class = None for name, obj in vars(mod).items(): if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: @@ -87,21 +95,21 @@ class Extensible: break if not extension_class: - logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") + logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") continue + # Load schema if not builtin json_data: dict[str, Any] = {} if not builtin: - if "schema.json" not in file_names: + json_path = os.path.join(subdir_path, "schema.json") + if not os.path.exists(json_path): logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, "schema.json") - json_data = {} - if os.path.exists(json_path): - with open(json_path, encoding="utf-8") as f: - json_data = json.load(f) + with open(json_path, encoding="utf-8") as f: + json_data = json.load(f) + # Create extension extensions.append( ModuleExtension( extension_class=extension_class, @@ -113,6 +121,11 @@ class Extensible: ) ) + except Exception as e: + logging.exception("Error scanning extensions") + raise + + # Sort extensions by position sorted_extensions = sort_to_dict_by_position_map( position_map=position_map, data=extensions, name_func=lambda x: x.name ) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 373ef2bbe2..568149cc37 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): deprecated: bool = False model_config = ConfigDict(protected_namespaces=()) + @property + def support_structure_output(self) -> bool: + return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features + class ParameterRule(BaseModel): """ diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7570200175..488a394679 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc 
import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod @@ -416,17 +412,12 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - # Get all provider model records of the workspace - provider_models = ( - db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - .all() - ) - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod @@ -437,17 +428,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = ( - db.session.query(TenantPreferredModelProvider) - .filter(TenantPreferredModelProvider.tenant_id == tenant_id) - .all() - ) - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - + provider_name_to_preferred_provider_type_records_dict = {} + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod @@ -458,18 +446,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = ( - db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() - ) - provider_name_to_provider_model_settings_dict = defaultdict(list) - for provider_model_setting in provider_model_settings: - ( + with Session(db.engine, expire_on_commit=False) as session: + stmt = 
select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_model_setting ) - ) - return provider_name_to_provider_model_settings_dict @staticmethod @@ -492,15 +476,14 @@ class ProviderManager: if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() - ) - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - for provider_load_balancing_config in provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict @@ -626,10 +609,9 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if ( - custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{") - ): + if custom_provider_record.encrypted_config is None: + raise ValueError("No credentials found") + if not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -733,7 +715,7 @@ class ProviderManager: return SystemConfiguration(enabled=False) # Convert provider_records to dict - quota_type_to_provider_records_dict = {} + quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -758,6 +740,11 @@ class ProviderManager: else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] + if provider_record.quota_used is None: + raise ValueError("quota_used is None") + if provider_record.quota_limit is None: + raise ValueError("quota_limit is None") + quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -791,10 +778,9 @@ class ProviderManager: cached_provider_credentials = provider_credentials_cache.get() if not cached_provider_credentials: - try: - provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} + if provider_records and provider_records[0].encrypted_config: + provider_credentials = json.loads(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 486b4b01af..36d0688807 100644 --- 
a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: dict | None = None - structured_output_enabled: bool = False + # We used 'structured_output_enabled' in the past, but it's not a good name. + structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") @field_validator("prompt_config", mode="before") @classmethod @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): if v is None: return PromptConfig() return v + + @property + def structured_output_enabled(self) -> bool: + return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index df8f614db3..ee181cf3bf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,9 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory @@ -74,7 +72,6 @@ from core.workflow.nodes.event import ( from core.workflow.utils.structured_output.entities import ( ResponseFormat, SpecialModelType, - SupportStructuredOutputStatus, ) from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]): llm_usage=usage, ) ) - except LLMNodeError as e: + except ValueError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = node_data_model.name - provider_name = node_data_model.provider + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=node_data_model.provider, + model=node_data_model.name, ) - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials + model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM + provider_model = model.provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, model_type=ModelType.LLM ) if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise 
ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() # model config - completion_params = node_data_model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise LLMModeRequiredError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + if self.node_data.structured_output_enabled: + if model_schema.support_structure_output: + node_data_model.completion_params = self._handle_native_json_schema( + node_data_model.completion_params, model_schema.parameter_rules + ) + else: + # Set appropriate response format based on model capabilities + self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) + + return model, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, + mode=node_data_model.mode, + provider_model_bundle=model.provider_model_bundle, + credentials=model.credentials, + parameters=node_data_model.completion_params, stop=stop, ) @@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." 
) - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) - stop = model_config.stop - return filtered_prompt_messages, stop + + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=self.node_data.model.provider, + model=self.node_data.model.name, + ) + model_schema = model.model_type_instance.get_model_schema( + model=self.node_data.model.name, + credentials=model.credentials, + ) + if not model_schema: + raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") + if self.node_data.structured_output_enabled: + if not model_schema.support_structure_output: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) + return filtered_prompt_messages, model_config.stop def _parse_structured_output(self, result_text: str) -> dict[str, Any]: structured_output: dict[str, Any] = {} @@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] + variable_mapping["#files#"] = node_data.vision.configs.variable_selector if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] @@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") - def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: - """ - Check if the current model supports structured output. 
- - Returns: - SupportStructuredOutput: The support status of structured output - """ - # Early return if structured output is disabled - if ( - not isinstance(self.node_data, LLMNodeData) - or not self.node_data.structured_output_enabled - or not self.node_data.structured_output - ): - return SupportStructuredOutputStatus.DISABLED - # Get model schema and check if it exists - model_schema = self._fetch_model_schema(self.node_data.model.provider) - if not model_schema: - return SupportStructuredOutputStatus.DISABLED - - # Check if model supports structured output feature - return ( - SupportStructuredOutputStatus.SUPPORTED - if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) - else SupportStructuredOutputStatus.UNSUPPORTED - ) - def _save_multimodal_output_and_convert_result_to_markdown( self, contents: str | list[PromptMessageContentUnionTypes] | None, diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py index 7954acbaee..6491042bfe 100644 --- a/api/core/workflow/utils/structured_output/entities.py +++ b/api/core/workflow/utils/structured_output/entities.py @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): GEMINI = "gemini" OLLAMA = "ollama" - - -class SupportStructuredOutputStatus(StrEnum): - """Constants for structured output support status""" - - SUPPORTED = "supported" - UNSUPPORTED = "unsupported" - DISABLED = "disabled" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 26bd6b3577..a837552007 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery: "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", "schedule.mail_clean_document_notify_task", + "schedule.queue_monitor_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", "schedule": crontab(minute="0", hour="10", day_of_week="1"), }, + "datasets-queue-monitor": { + "task": "schedule.queue_monitor_task.queue_monitor_task", + "schedule": timedelta( + minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 + ), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/models/provider.py b/api/models/provider.py index 497cbefc61..1e25f0c90f 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,6 +1,9 @@ +from datetime import datetime from enum import Enum +from typing import Optional -from sqlalchemy import func +from sqlalchemy import func, text +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -51,20 +54,24 @@ class Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used = db.Column(db.DateTime, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = 
mapped_column(db.String(255), nullable=False) + provider_type: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'custom'::character varying") + ) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) + quota_type: Mapped[Optional[str]] = mapped_column( + db.String(40), nullable=True, server_default=text("''::character varying") + ) + quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,15 +111,15 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -122,13 +129,13 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -153,22 +160,24 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - account_id = db.Column(StringUUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + quantity: Mapped[int] = 
mapped_column(db.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_status: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + ) + paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -182,15 +191,15 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = 
mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py new file mode 100644 index 0000000000..e3a7021b9d --- /dev/null +++ b/api/schedule/queue_monitor_task.py @@ -0,0 +1,62 @@ +import logging +from datetime import datetime +from urllib.parse import urlparse + +import click +from flask import render_template +from redis import Redis + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail + +# Create a dedicated Redis connection (using the same configuration as Celery) +celery_broker_url = dify_config.CELERY_BROKER_URL + +parsed = urlparse(celery_broker_url) +host = parsed.hostname or "localhost" +port = parsed.port or 6379 +password = parsed.password or None +redis_db = parsed.path.strip("/") or "1" # type: ignore + +celery_redis = Redis(host=host, port=port, password=password, db=redis_db) + + +@app.celery.task(queue="monitor") +def queue_monitor_task(): + queue_name = "dataset" + threshold = dify_config.QUEUE_MONITOR_THRESHOLD + + try: + queue_length = celery_redis.llen(f"{queue_name}") + logging.info(click.style(f"Start monitor {queue_name}", fg="green")) + logging.info(click.style(f"Queue length: {queue_length}", fg="green")) + + if queue_length >= threshold: + warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}" + logging.warning(click.style(warning_msg, fg="red")) + alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS + if alter_emails: + to_list = alter_emails.split(",") + for to in to_list: + try: + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + html_content = render_template( + "queue_monitor_alert_email_template_en-US.html", + queue_name=queue_name, + queue_length=queue_length, + threshold=threshold, + alert_time=current_time, + ) + mail.send( + to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content + ) + except Exception as e: + logging.exception(click.style("Exception occurred during sending email", fg="red")) + + except Exception as e: + logging.exception(click.style("Exception occurred during queue monitoring", fg="red")) + finally: + if db.session.is_active: + db.session.close() diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index f32bc4f187..51b6343fdc 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -5,7 +5,7 @@ import uuid import click from celery import shared_task # type: ignore -from sqlalchemy import func, select +from sqlalchemy import func from sqlalchemy.orm import Session from core.model_manager import ModelManager @@ -68,11 +68,6 @@ def batch_create_segment_to_index_task( model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - word_count_change = 0 - 
segments_to_insert: list[str] = [] - max_position_stmt = select(func.max(DocumentSegment.position)).where( - DocumentSegment.document_id == dataset_document.id - ) word_count_change = 0 if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens( diff --git a/api/templates/queue_monitor_alert_email_template_en-US.html b/api/templates/queue_monitor_alert_email_template_en-US.html new file mode 100644 index 0000000000..2885210864 --- /dev/null +++ b/api/templates/queue_monitor_alert_email_template_en-US.html @@ -0,0 +1,129 @@ + + + + + + + + +
+ [alert email body text; markup omitted]
+ Dify Logo
+ Queue Monitoring Alert
+ Our system has detected an abnormal queue status that requires your attention:
+ Queue Task Alert
+ Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})
+ Recommended actions:
+ 1. Check the queue processing status in the system dashboard
+ 2. Verify if there are any processing bottlenecks
+ 3. Consider scaling up workers if needed
+ Additional Information:
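The template above is rendered by the new `queue_monitor_task` added in `api/schedule/queue_monitor_task.py`. A condensed, standalone sketch of the task's core check — the Redis connection fields, the `dataset` queue name, and the threshold default mirror the task itself, while the function name and boolean return are illustrative:

```python
import logging
from urllib.parse import urlparse

from redis import Redis


def dataset_queue_backlog(broker_url: str, queue_name: str = "dataset", threshold: int = 200) -> bool:
    """Return True when the pending-task backlog of a Celery queue exceeds the threshold."""
    parsed = urlparse(broker_url)
    client = Redis(
        host=parsed.hostname or "localhost",
        port=parsed.port or 6379,
        password=parsed.password or None,
        db=int(parsed.path.strip("/") or "1"),  # same DB the Celery broker URL points at
    )
    # Celery's Redis transport keeps waiting tasks in a list named after the queue,
    # so LLEN counts tasks that no worker has picked up yet.
    queue_length = client.llen(queue_name)
    logging.info("Queue %s length: %s/%s", queue_name, queue_length, threshold)
    return queue_length >= threshold
```

When the check trips, the task logs a warning and sends one alert email per address in `QUEUE_MONITOR_ALERT_EMAILS` (comma-separated), rendering the template above with the queue name, length, threshold, and alert time.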
+ + + diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5fbee266bd..6aa48b1cbb 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -3,11 +3,16 @@ import os import time import uuid from collections.abc import Generator -from unittest.mock import MagicMock +from decimal import Decimal +from unittest.mock import MagicMock, patch import pytest +from app_factory import create_app +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey @@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType -from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +@pytest.fixture(scope="session") +def app(): + # Set up storage configuration + os.environ["STORAGE_TYPE"] = "opendal" + os.environ["OPENDAL_SCHEME"] = "fs" + os.environ["OPENDAL_FS_ROOT"] = "storage" + + # Ensure storage directory exists + os.makedirs("storage", exist_ok=True) + + app = create_app() + dify_config.LOGIN_DISABLED = True + return app + + def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode: graph = Graph.init(graph_config=graph_config) + # Use proper UUIDs for database compatibility + tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" + workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" + user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" + init_params = GraphInitParams( - tenant_id="1", - app_id="1", + tenant_id=tenant_id, + app_id=app_id, workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", + workflow_id=workflow_id, graph_config=graph_config, - user_id="1", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, call_depth=0, @@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(setup_model_mock): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, +def test_execute_llm(app): + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - {"role": "system", 
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") - # execute node - result = node._run() - assert isinstance(result, Generator) + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): +def test_execute_llm_with_jinja2(app, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", 
"mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", - }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + # Mock db.session.close() + db.session.close = MagicMock() - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") - # execute node - result = node._run() + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" 
in json.dumps(item.run_result.process_data) + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance + + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json(): diff --git a/docker/.env.example b/docker/.env.example index ac9536be03..4cf5e202d0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600 PIP_MIRROR_URL= # https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example -# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss +# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos PLUGIN_STORAGE_TYPE=local PLUGIN_STORAGE_LOCAL_ROOT=/app/storage PLUGIN_WORKING_PATH=/app/storage/cwd @@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 PLUGIN_ALIYUN_OSS_PATH= +# Plugin oss volcengine tos +PLUGIN_VOLCENGINE_TOS_ENDPOINT= +PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= +PLUGIN_VOLCENGINE_TOS_SECRET_KEY= +PLUGIN_VOLCENGINE_TOS_REGION= # ------------------------------ # OTLP Collector Configuration @@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000 # Prevent Clickjacking ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. 
eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 74a7b87bf9..75bdab1a06 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -184,6 +184,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" volumes: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index d4a0b94619..8276e2977f 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -121,6 +121,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 41e86d015f..e559021684 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ENABLE_OTEL: ${ENABLE_OTEL:-false} OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318} OTLP_API_KEY: ${OTLP_API_KEY:-} @@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000} OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000} ALLOW_EMBED: ${ALLOW_EMBED:-false} + QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} + QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} + QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} services: # API service @@ -683,6 +690,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" volumes: 
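The `QUEUE_MONITOR_*` variables threaded through the compose files above drive the new Celery beat entry from `api/extensions/ext_celery.py`; a minimal sketch of that wiring, assuming the `DatasetQueueMonitorConfig` settings defined earlier in this diff:

```python
from datetime import timedelta

from configs import dify_config

beat_schedule = {
    "datasets-queue-monitor": {
        "task": "schedule.queue_monitor_task.queue_monitor_task",
        # QUEUE_MONITOR_INTERVAL is in minutes; fall back to 30 when unset.
        "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
    },
}
```

`QUEUE_MONITOR_THRESHOLD` and `QUEUE_MONITOR_ALERT_EMAILS` are read inside the task itself, so they take effect without touching the beat schedule.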
diff --git a/docker/middleware.env.example b/docker/middleware.env.example
index ba6859885b..66037f281c 100644
--- a/docker/middleware.env.example
+++ b/docker/middleware.env.example
@@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
 PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
 PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
 PLUGIN_ALIYUN_OSS_PATH=
+# Plugin oss volcengine tos
+PLUGIN_VOLCENGINE_TOS_ENDPOINT=
+PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
+PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
+PLUGIN_VOLCENGINE_TOS_REGION=
diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx
index 975f8aeb6c..048ae2a576 100644
--- a/web/app/components/base/app-icon-picker/index.tsx
+++ b/web/app/components/base/app-icon-picker/index.tsx
@@ -5,7 +5,6 @@ import type { Area } from 'react-easy-crop'
 import Modal from '../modal'
 import Divider from '../divider'
 import Button from '../button'
-import { ImagePlus } from '../icons/src/vender/line/images'
 import { useLocalFileUploader } from '../image-uploader/hooks'
 import EmojiPickerInner from '../emoji-picker/Inner'
 import type { OnImageInput } from './ImageInput'
@@ -16,6 +15,7 @@ import type { AppIconType, ImageFile } from '@/types/app'
 import cn from '@/utils/classnames'
 import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
 import { noop } from 'lodash-es'
+import { RiImageCircleAiLine } from '@remixicon/react'

 export type AppIconEmojiSelection = {
   type: 'emoji'
@@ -46,7 +46,7 @@ const AppIconPicker: FC = ({
   const tabs = [
     { key: 'emoji', label: t('app.iconPicker.emoji'), icon: 🤖 },
-    { key: 'image', label: t('app.iconPicker.image'), icon: <ImagePlus /> },
+    { key: 'image', label: t('app.iconPicker.image'), icon: <RiImageCircleAiLine /> },
   ]

   const [activeTab, setActiveTab] = useState('emoji')

@@ -119,10 +119,10 @@ const AppIconPicker: FC = ({
         {tabs.map(tab => (
+          {/* Data source form Preview */}
+          {/* Process documents form Preview */}
 )
diff --git a/web/app/components/rag-pipeline/components/input-field/preview/process-documents.tsx b/web/app/components/rag-pipeline/components/input-field/preview/process-documents.tsx
index 2d93f66b75..caabf4bde6 100644
--- a/web/app/components/rag-pipeline/components/input-field/preview/process-documents.tsx
+++ b/web/app/components/rag-pipeline/components/input-field/preview/process-documents.tsx
@@ -16,7 +16,7 @@ const ProcessDocuments = ({
   const { data: paramsConfig } = useDraftPipelineProcessingParams({
     pipeline_id: pipelineId!,
     node_id: dataSourceNodeId,
-  })
+  }, !!pipelineId && !!dataSourceNodeId)

   return (
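The process-documents change above threads an explicit enabled flag into the hook rather than hard-coding the guard inside it, so each call site can state its own preconditions. A minimal sketch of the pattern with @tanstack/react-query; fetchParams, useProcessingParams, and the URL are stand-ins, not the project's real service layer:

import { useQuery } from '@tanstack/react-query'

// Stand-in for the project's real `get(...)` service call; the URL is illustrative.
const fetchParams = (pipelineId: string, nodeId: string) =>
  fetch(`/rag/pipelines/${pipelineId}/parameters?node_id=${nodeId}`).then(res => res.json())

// A query hook that lets each caller gate execution explicitly.
export const useProcessingParams = (pipelineId?: string, nodeId?: string, enabled = true) =>
  useQuery({
    queryKey: ['pipeline-processing-params', pipelineId, nodeId],
    // The non-null assertions are safe only because `enabled` keeps the
    // query from running until both values exist.
    queryFn: () => fetchParams(pipelineId!, nodeId!),
    enabled,
  })

// Usage mirroring the change above:
// const { data } = useProcessingParams(pipelineId, nodeId, !!pipelineId && !!nodeId)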
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx
index 9c00aa1d7d..afbba09594 100644
--- a/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx
+++ b/web/app/components/rag-pipeline/components/panel/test-run/data-source-options/index.tsx
@@ -4,12 +4,12 @@ import OptionCard from './option-card'
 import type { Datasource } from '../types'

 type DataSourceOptionsProps = {
-  datasourceNodeId: string
+  dataSourceNodeId: string
   onSelect: (option: Datasource) => void
 }

 const DataSourceOptions = ({
-  datasourceNodeId,
+  dataSourceNodeId,
   onSelect,
 }: DataSourceOptionsProps) => {
   const { datasources, options } = useDatasourceOptions()
@@ -34,7 +34,7 @@ const DataSourceOptions = ({
           key={option.value}
           label={option.label}
           nodeData={option.data}
-          selected={datasourceNodeId === option.value}
+          selected={dataSourceNodeId === option.value}
           onClick={handelSelect.bind(null, option.value)}
         />
       ))}
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/base/crawler.tsx b/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/base/crawler.tsx
index e3f3edac0a..4c598bfc10 100644
--- a/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/base/crawler.tsx
+++ b/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/base/crawler.tsx
@@ -1,5 +1,5 @@
 'use client'
-import React, { useCallback, useEffect, useState } from 'react'
+import React, { useCallback, useEffect, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import type { CrawlResultItem } from '@/models/datasets'
 import Header from '@/app/components/datasets/create/website/base/header'
@@ -7,15 +7,17 @@ import Options from './options'
 import Crawling from './crawling'
 import ErrorMessage from './error-message'
 import CrawledResult from './crawled-result'
-import type { RAGPipelineVariables } from '@/models/pipeline'
-import { useDatasourceNodeRun } from '@/service/use-pipeline'
-import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
+import {
+  useDatasourceNodeRun,
+  useDraftPipelinePreProcessingParams,
+  usePublishedPipelinePreProcessingParams,
+} from '@/service/use-pipeline'
+import { useStore } from '@/app/components/workflow/store'

 const I18N_PREFIX = 'datasetCreation.stepOne.website'

 type CrawlerProps = {
   nodeId: string
-  variables: RAGPipelineVariables
   checkedCrawlResult: CrawlResultItem[]
   onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void
   onJobIdChange: (jobId: string) => void
@@ -25,6 +27,7 @@ type CrawlerProps = {
     docLink: string
   }
   onPreview?: (payload: CrawlResultItem) => void
+  usingPublished?: boolean
 }

 enum Step {
@@ -35,17 +38,23 @@
 const Crawler = ({
   nodeId,
-  variables,
   checkedCrawlResult,
   headerInfo,
   onCheckedCrawlResultChange,
   onJobIdChange,
   onPreview,
+  usingPublished = false,
 }: CrawlerProps) => {
   const { t } = useTranslation()
   const [step, setStep] = useState(Step.init)
   const [controlFoldOptions, setControlFoldOptions] = useState(0)
-  const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id)
+  const pipelineId = useStore(s => s.pipelineId)
+
+  const usePreProcessingParams = useRef(usingPublished ? usePublishedPipelinePreProcessingParams : useDraftPipelinePreProcessingParams)
+  const { data: paramsConfig } = usePreProcessingParams.current({
+    pipeline_id: pipelineId!,
+    node_id: nodeId,
+  }, !!pipelineId && !!nodeId)

   useEffect(() => {
     if (step !== Step.init)
@@ -95,7 +104,7 @@ const Crawler = ({
         />
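The useRef above pins the chosen hook for the lifetime of the component: React requires the same hooks to run in the same order on every render, so the component must not flip between the draft and published hooks after mount. A sketch of the pattern under that assumption; useDraftParams and usePublishedParams are hypothetical stand-ins for the real pre-processing-params hooks:

import { useRef } from 'react'
import { useQuery } from '@tanstack/react-query'

// Two hooks with identical signatures, standing in for the draft/published variants.
const useDraftParams = (nodeId: string, enabled: boolean) =>
  useQuery({
    queryKey: ['draft-params', nodeId],
    queryFn: () => fetch(`/draft/${nodeId}`).then(res => res.json()),
    enabled,
  })
const usePublishedParams = (nodeId: string, enabled: boolean) =>
  useQuery({
    queryKey: ['published-params', nodeId],
    queryFn: () => fetch(`/published/${nodeId}`).then(res => res.json()),
    enabled,
  })

// The ref freezes the choice on the first render, keeping the hook call
// order stable even if `usingPublished` were to change later.
export function usePreProcessingParams(usingPublished: boolean, nodeId: string) {
  const chosenHook = useRef(usingPublished ? usePublishedParams : useDraftParams)
  return chosenHook.current(nodeId, !!nodeId)
}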
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/index.tsx
index 181322256f..50be76e524 100644
--- a/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/index.tsx
+++ b/web/app/components/rag-pipeline/components/panel/test-run/data-source/website-crawl/index.tsx
@@ -1,12 +1,10 @@
 'use client'
 import React from 'react'
 import type { CrawlResultItem } from '@/models/datasets'
-import type { RAGPipelineVariables } from '@/models/pipeline'
 import Crawler from './base/crawler'

 type WebsiteCrawlProps = {
   nodeId: string
-  variables: RAGPipelineVariables
   checkedCrawlResult: CrawlResultItem[]
   onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void
   onJobIdChange: (jobId: string) => void
@@ -16,26 +14,27 @@ type WebsiteCrawlProps = {
     docLink: string
   }
   onPreview?: (payload: CrawlResultItem) => void
+  usingPublished?: boolean
 }

 const WebsiteCrawl = ({
   nodeId,
-  variables,
   checkedCrawlResult,
   headerInfo,
   onCheckedCrawlResultChange,
   onJobIdChange,
   onPreview,
+  usingPublished,
 }: WebsiteCrawlProps) => {
   return (
   )
 }
diff --git a/web/app/components/rag-pipeline/components/panel/test-run/index.tsx b/web/app/components/rag-pipeline/components/panel/test-run/index.tsx
index 84b48c3653..97c0a4002f 100644
--- a/web/app/components/rag-pipeline/components/panel/test-run/index.tsx
+++ b/web/app/components/rag-pipeline/components/panel/test-run/index.tsx
@@ -117,7 +117,7 @@ const TestRunPanel = () => {
       <>
         {datasource?.type === DatasourceType.localFile && (
@@ -139,7 +139,6 @@ const TestRunPanel = () => {
         {datasource?.type === DatasourceType.websiteCrawl && (
diff --git a/web/service/use-pipeline.ts b/web/service/use-pipeline.ts
--- a/web/service/use-pipeline.ts
+++ b/web/service/use-pipeline.ts
-export const useDraftPipelineProcessingParams = (params: PipelineProcessingParamsRequest) => {
+export const useDraftPipelineProcessingParams = (params: PipelineProcessingParamsRequest, enabled = true) => {
   const { pipeline_id, node_id } = params
   return useQuery({
-    queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id],
+    queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id, node_id],
     queryFn: () => {
       return get(`/rag/pipelines/${pipeline_id}/workflows/draft/processing/parameters`, {
         params: {
@@ -148,14 +150,14 @@
       })
     },
     staleTime: 0,
-    enabled: !!pipeline_id && !!node_id,
+    enabled,
   })
 }

 export const usePublishedPipelineProcessingParams = (params: PipelineProcessingParamsRequest) => {
   const { pipeline_id, node_id } = params
   return useQuery({
-    queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id],
+    queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id, node_id],
     queryFn: () => {
       return get(`/rag/pipelines/${pipeline_id}/workflows/published/processing/parameters`, {
         params: {
@@ -163,6 +165,7 @@
         },
       })
     },
+    staleTime: 0,
   })
 }

@@ -248,3 +251,35 @@
     },
   })
 }
+
+export const useDraftPipelinePreProcessingParams = (params: PipelinePreProcessingParamsRequest, enabled = true) => {
+  const { pipeline_id, node_id } = params
+  return useQuery({
+    queryKey: [NAME_SPACE, 'pipeline-pre-processing-params', pipeline_id, node_id],
+    queryFn: () => {
+      return get(`/rag/pipelines/${pipeline_id}/workflows/draft/pre-processing/parameters`, {
+        params: {
+          node_id,
+        },
+      })
+    },
+    staleTime: 0,
+    enabled,
+  })
+}
+
+export const usePublishedPipelinePreProcessingParams = (params: PipelinePreProcessingParamsRequest, enabled = true) => {
+  const { pipeline_id, node_id } = params
+  return useQuery({
+    queryKey: [NAME_SPACE, 'pipeline-pre-processing-params', pipeline_id, node_id],
+    queryFn: () => {
+      return get(`/rag/pipelines/${pipeline_id}/workflows/published/pre-processing/parameters`, {
+        params: {
+          node_id,
+        },
+      })
+    },
+    staleTime: 0,
+    enabled,
+  })
+}
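The recurring service-layer fix above is worth spelling out: the query keys previously omitted node_id, so two components asking about different nodes of the same pipeline resolved to one cache entry, and whichever fetched first supplied the other's data. Adding node_id to the key and setting staleTime: 0 keeps entries distinct and refetched on mount. A condensed sketch; fetchParams, useNodeParams, and the URL are illustrative, not the project's real code:

import { useQuery } from '@tanstack/react-query'

// Stand-in for the real service call.
const fetchParams = (pipelineId: string, nodeId: string) =>
  fetch(`/rag/pipelines/${pipelineId}/parameters?node_id=${nodeId}`).then(res => res.json())

const useNodeParams = (pipelineId: string, nodeId: string, enabled = true) =>
  useQuery({
    // Every input that changes the response belongs in the key; with nodeId
    // missing, a second panel querying another node would get a cache hit
    // and render the first panel's parameters.
    queryKey: ['pipeline-processing-params', pipelineId, nodeId],
    queryFn: () => fetchParams(pipelineId, nodeId),
    // staleTime: 0 marks data stale immediately, so remounting components
    // refetch rather than reuse a possibly outdated draft configuration.
    staleTime: 0,
    enabled,
  })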