diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index b6c9131c08..9c79dbc57e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -89,7 +89,9 @@ jobs: - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run lint + run: | + pnpm run lint + pnpm run eslint docker-compose-template: name: Docker Compose Template diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 19ca464a79..f730cfa3fe 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -44,22 +44,19 @@ def oauth_server_access_token_required(view): if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp): raise BadRequest("Invalid oauth_provider_app") - if not request.headers.get("Authorization"): - raise BadRequest("Authorization is required") - authorization_header = request.headers.get("Authorization") if not authorization_header: raise BadRequest("Authorization header is required") - parts = authorization_header.split(" ") + parts = authorization_header.strip().split(" ") if len(parts) != 2: raise BadRequest("Invalid Authorization header format") - token_type = parts[0] - if token_type != "Bearer": + token_type = parts[0].strip() + if token_type.lower() != "bearer": raise BadRequest("token_type is invalid") - access_token = parts[1] + access_token = parts[1].strip() if not access_token: raise BadRequest("access_token is required") @@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource): parser.add_argument("refresh_token", type=str, required=False, location="json") parsed_args = parser.parse_args() - grant_type = OAuthGrantType(parsed_args["grant_type"]) + try: + grant_type = OAuthGrantType(parsed_args["grant_type"]) + except ValueError: + raise BadRequest("invalid grant_type") if grant_type == OAuthGrantType.AUTHORIZATION_CODE: if not parsed_args["code"]: @@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource): "refresh_token": refresh_token, } ) - else: - raise BadRequest("invalid grant_type") class OAuthServerUserAccountApi(Resource): diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 8d50b0d41c..22bb81f9e3 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -354,9 +354,6 @@ class DatasetInitApi(Resource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator - if not current_user.is_dataset_editor: - raise Forbidden() knowledge_config = KnowledgeConfig(**args) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index c5aa318f58..de4f1da801 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -1,8 +1,12 @@ from base64 import b64encode +from collections.abc import Callable from functools import wraps from hashlib import sha1 from hmac import new as hmac_new +from typing import ParamSpec, TypeVar +P = ParamSpec("P") +R = TypeVar("R") from flask import abort, request from configs import dify_config @@ -10,9 +14,9 @@ from 
extensions.ext_database import db from models.model import EndUser -def billing_inner_api_only(view): +def billing_inner_api_only(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: abort(404) @@ -26,9 +30,9 @@ def billing_inner_api_only(view): return decorated -def enterprise_inner_api_only(view): +def enterprise_inner_api_only(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: abort(404) @@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view): return decorated -def plugin_inner_api_only(view): +def plugin_inner_api_only(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.PLUGIN_DAEMON_KEY: abort(404) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index d2473c15af..cc4b5f65bd 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,7 +1,7 @@ import time from collections.abc import Callable from datetime import timedelta -from enum import Enum +from enum import StrEnum, auto from functools import wraps from typing import Optional @@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService -class WhereisUserArg(Enum): +class WhereisUserArg(StrEnum): """ Enum for whereis_user_arg. """ - QUERY = "query" - JSON = "json" - FORM = "form" + QUERY = auto() + JSON = auto() + FORM = auto() class FetchUserArg(BaseModel): diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f7c83f927f..f5e45bcb47 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() + stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id) + agent_thought = db.session.scalar(stmt) if not agent_thought: raise ValueError("agent thought not found") @@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner): return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() + stmt = select(MessageFile).where(MessageFile.message_id == message.id) + files = db.session.scalars(stmt).all() if not files: return UserPromptMessage(content=message.query) if message.app_model_config: diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 52ae20ee16..74e282fdcd 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): worker_thread.start() + # release database connection, because the following new thread operations may take a long time + db.session.refresh(workflow) + db.session.refresh(message) + db.session.refresh(user) + db.session.close() + # return response or stream generator response = self._handle_advanced_chat_response( application_generate_entity=application_generate_entity, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 452dbbec01..2a8efad15e 
100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + with Session(db.engine, expire_on_commit=False) as session: + app_record = session.scalar(select(App).where(App.id == app_config.app_id)) + if not app_record: raise ValueError("App not found") @@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): environment_variables=self._workflow.environment_variables, # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - conversation_variables=cast(list[VariableUnion], conversation_variables), + conversation_variables=conversation_variables, ) # init graph diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index b08bd9c872..10e51ff338 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager -from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Conversation, EndUser, Message, MessageFile @@ -886,10 +885,6 @@ class AdvancedChatAppGenerateTaskPipeline: self._task_state.metadata.usage = usage else: self._task_state.metadata.usage = LLMUsage.empty_usage() - message_was_created.send( - message, - application_generate_entity=self._application_generate_entity, - ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 39d6ba39f5..d3207365f3 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from core.agent.cot_chat_agent_runner import CotChatAgentRunner from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from core.agent.entities import AgentEntity @@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(AgentChatAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + app_stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(app_stmt) if not app_record: raise ValueError("App not found") @@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - - conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() + conversation_stmt = select(Conversation).where(Conversation.id == conversation.id) + conversation_result = 
db.session.scalar(conversation_stmt) if conversation_result is None: raise ValueError("Conversation not found") - message_result = db.session.query(Message).where(Message.id == message.id).first() + msg_stmt = select(Message).where(Message.id == message.id) + message_result = db.session.scalar(msg_stmt) if message_result is None: raise ValueError("Message not found") db.session.close() diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 2cffe4a0a5..feacca1a07 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,7 @@ import queue import time from abc import abstractmethod -from enum import Enum +from enum import IntEnum, auto from typing import Any, Optional from sqlalchemy.orm import DeclarativeMeta @@ -19,9 +19,9 @@ from core.app.entities.queue_entities import ( from extensions.ext_redis import redis_client -class PublishFrom(Enum): - APPLICATION_MANAGER = 1 - TASK_PIPELINE = 2 +class PublishFrom(IntEnum): + APPLICATION_MANAGER = auto() + TASK_PIPELINE = auto() class AppQueueManager: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 894d7906d5..4385d0f08d 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.apps.chat.app_config_manager import ChatAppConfig @@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(ChatAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(stmt) if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 64dade2968..8d2f3d488b 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError +from sqlalchemy import select from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter @@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - message = ( - db.session.query(Message) - .where( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ("api" if isinstance(user, EndUser) else "console"), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ) - .first() + stmt = select(Message).where( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), ) + message = db.session.scalar(stmt) if not message: raise MessageNotExistsError() diff --git a/api/core/app/apps/completion/app_runner.py 
b/api/core/app/apps/completion/app_runner.py index 50d2a0036c..d384bff255 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from sqlalchemy import select + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.apps.completion.app_config_manager import CompletionAppConfig @@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner): """ app_config = application_generate_entity.app_config app_config = cast(CompletionAppConfig, app_config) - - app_record = db.session.query(App).where(App.id == app_config.app_id).first() + stmt = select(App).where(App.id == app_config.app_id) + app_record = db.session.scalar(stmt) if not app_record: raise ValueError("App not found") diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 11c979765b..92f3b6507c 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -3,6 +3,9 @@ import logging from collections.abc import Generator from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager @@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = ( - db.session.query(AppModelConfig) - .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) - .first() + stmt = select(AppModelConfig).where( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id ) + app_model_config = db.session.scalar(stmt) if not app_model_config: raise AppModelConfigBrokenError() @@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param conversation_id: conversation id :return: conversation """ - conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() + with Session(db.engine, expire_on_commit=False) as session: + conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id)) if not conversation: raise ConversationNotExistsError("Conversation not exists") @@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator): :param message_id: message id :return: message """ - message = db.session.query(Message).where(Message.id == message_id).first() + with Session(db.engine, expire_on_commit=False) as session: + message = session.scalar(select(Message).where(Message.id == message_id)) if message is None: raise MessageNotExistsError("Message not exists") diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index b829340401..be183e2086 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -1,6 +1,8 @@ import logging from typing import Optional +from sqlalchemy import select + from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ 
-25,9 +27,8 @@ class AnnotationReplyFeature: :param invoke_from: invoke from :return: """ - annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() - ) + stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id) + annotation_setting = db.session.scalar(stmt) if not annotation_setting: return None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 471118c8cb..e3b917067f 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :param event: agent thought event :return: """ - agent_thought: Optional[MessageAgentThought] = ( - db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() - ) + with Session(db.engine, expire_on_commit=False) as session: + agent_thought: Optional[MessageAgentThought] = ( + session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() + ) if agent_thought: return AgentThoughtStreamResponse( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 50b51f70fe..5c19eda21e 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -3,6 +3,8 @@ from threading import Thread from typing import Optional, Union from flask import Flask, current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ( @@ -84,7 +86,8 @@ class MessageCycleManager: def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() + stmt = select(Conversation).where(Conversation.id == conversation_id) + conversation = db.session.scalar(stmt) if not conversation: return @@ -143,7 +146,8 @@ class MessageCycleManager: :param event: event :return: """ - message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() + with Session(db.engine, expire_on_commit=False) as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id)) if message_file and message_file.url is not None: # get tool file id @@ -183,7 +187,8 @@ class MessageCycleManager: :param message_id: message id :return: """ - message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first() + with Session(db.engine, expire_on_commit=False) as session: + message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id)) event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE return MessageStreamResponse( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index c55ba5e0fe..5cf39d7611 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from sqlalchemy import select + from 
core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler: for document in documents: if document.metadata is not None: document_id = document.metadata["document_id"] - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) + dataset_document = db.session.scalar(dataset_document_stmt) if not dataset_document: _logger.warning( "Expected DatasetDocument record to exist, but none was found, document_id=%s", @@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler: ) continue if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ( - db.session.query(ChildChunk) - .where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - .first() + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, ) + child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: segment = ( db.session.query(DocumentSegment) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index d81f372d40..2100e7fadc 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,5 +1,7 @@ from typing import Optional +from sqlalchemy import select + from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.external_data_tool.base import ExternalDataTool from core.helper import encrypter @@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): api_based_extension_id = config.get("api_based_extension_id") if not api_based_extension_id: raise ValueError("api_based_extension_id is required") - # get api_based_extension - api_based_extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) + api_based_extension = db.session.scalar(stmt) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool): raise ValueError(f"config is required, config: {self.config}") api_based_extension_id = self.config.get("api_based_extension_id") assert api_based_extension_id is not None, "api_based_extension_id is required" - # get api_based_extension - api_based_extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id ) + api_based_extension = db.session.scalar(stmt) if not api_based_extension: raise ValueError( diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index c6bb2007d6..17345dc203 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -3,7 +3,7 @@ import base64 from libs import rsa -def 
obfuscated_token(token: str): +def obfuscated_token(token: str) -> str: if not token: return token if len(token) <= 8: @@ -11,6 +11,10 @@ def obfuscated_token(token: str): return token[:6] + "*" * 12 + token[-2:] +def full_mask_token(token_length=20): + return "*" * token_length + + def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index fe3078923d..e837f2fd38 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -1,6 +1,6 @@ from collections.abc import Sequence -import requests +import httpx from yarl import URL from configs import dify_config @@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP return [] url = str(marketplace_api_url / "api/v1/plugins/batch") - response = requests.post(url, json={"plugin_ids": plugin_ids}) + response = httpx.post(url, json={"plugin_ids": plugin_ids}) response.raise_for_status() return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] @@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( return [] url = str(marketplace_api_url / "api/v1/plugins/batch") - response = requests.post(url, json={"plugin_ids": plugin_ids}) + response = httpx.post(url, json={"plugin_ids": plugin_ids}) response.raise_for_status() result: list[MarketplacePluginDeclaration] = [] for plugin in response.json()["data"]["plugins"]: @@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( def record_install_plugin_event(plugin_unique_identifier: str): url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") - response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) + response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier}) response.raise_for_status() diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index a8e6c261c2..4a768618f5 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -8,6 +8,7 @@ import uuid from typing import Any, Optional, cast from flask import current_app +from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -56,13 +57,11 @@ class IndexingRunner: if not dataset: raise ValueError("no dataset found") - # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() + stmt = select(DatasetProcessRule).where( + DatasetProcessRule.id == dataset_document.dataset_process_rule_id ) + processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") index_type = dataset_document.doc_form @@ -123,11 +122,8 @@ class IndexingRunner: db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) - .first() - ) + stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + processing_rule = db.session.scalar(stmt) if not processing_rule: raise ValueError("no process rule found") @@ -208,7 +204,6 @@ class IndexingRunner: child_documents.append(child_document) document.children = child_documents 
documents.append(document) - # build index index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -310,7 +305,8 @@ class IndexingRunner: # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + stmt = select(UploadFile).where(UploadFile.id == upload_file_id) + image_file = db.session.scalar(stmt) if image_file is None: continue try: @@ -339,10 +335,8 @@ class IndexingRunner: if dataset_document.data_source_type == "upload_file": if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - - file_detail = ( - db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() - ) + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() if file_detail: extract_setting = ExtractSetting( diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 6cbd949be9..cb768e2036 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -110,9 +110,9 @@ class TokenBufferMemory: else: message_limit = 500 - stmt = stmt.limit(message_limit) + msg_limit_stmt = stmt.limit(message_limit) - messages = db.session.scalars(stmt).all() + messages = db.session.scalars(msg_limit_stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 51af3d1877..e567565548 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -158,8 +158,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - - self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return cast( Union[LLMResult, Generator], self._round_robin_invoke( @@ -188,8 +186,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - - self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return cast( int, self._round_robin_invoke( @@ -214,8 +210,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - - self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return cast( TextEmbeddingResult, self._round_robin_invoke( @@ -237,8 +231,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - - self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return cast( list[int], self._round_robin_invoke( @@ -269,8 +261,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - - self.model_type_instance = cast(RerankModel, self.model_type_instance) return cast( RerankResult, self._round_robin_invoke( @@ -295,8 +285,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model 
type instance is not ModerationModel") - - self.model_type_instance = cast(ModerationModel, self.model_type_instance) return cast( bool, self._round_robin_invoke( @@ -318,8 +306,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") - - self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) return cast( str, self._round_robin_invoke( @@ -343,8 +329,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) return cast( Iterable[bytes], self._round_robin_invoke( @@ -404,8 +388,6 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - - self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( model=self.model, credentials=self.credentials, language=language ) diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index af51b72cd5..06d5c02bb8 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,6 +1,7 @@ from typing import Optional from pydantic import BaseModel, Field +from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token @@ -87,10 +88,9 @@ class ApiModeration(Moderation): @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: - extension = ( - db.session.query(APIBasedExtension) - .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + stmt = select(APIBasedExtension).where( + APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id ) + extension = db.session.scalar(stmt) return extension diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 77852e2a98..7caad89353 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -5,6 +5,7 @@ from typing import Optional from urllib.parse import urljoin from opentelemetry.trace import Link, Status, StatusCode +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( @@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance): app_id = trace_info.metadata.get("app_id") if not app_id: raise ValueError("No app_id found in trace_info metadata") - - app = session.query(App).where(App.id == app_id).first() + app_stmt = select(App).where(App.id == app_id) + app = session.scalar(app_stmt) if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - - service_account = session.query(Account).where(Account.id == app.created_by).first() + account_stmt = select(Account).where(Account.id == app.created_by) + service_account = session.scalar(account_stmt) if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") current_tenant = ( diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index f8e428daf1..04b46d67a8 100644 --- 
a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from sqlalchemy import select from sqlalchemy.orm import Session from core.ops.entities.config_entity import BaseTracingConfig @@ -44,14 +45,15 @@ class BaseTraceInstance(ABC): """ with Session(db.engine, expire_on_commit=False) as session: # Get the app to find its creator - app = session.query(App).where(App.id == app_id).first() + app_stmt = select(App).where(App.id == app_id) + app = session.scalar(app_stmt) if not app: raise ValueError(f"App with id {app_id} not found") if not app.created_by: raise ValueError(f"App with id {app_id} has no creator (created_by is None)") - - service_account = session.query(Account).where(Account.id == app.created_by).first() + account_stmt = select(Account).where(Account.id == app.created_by) + service_account = session.scalar(account_stmt) if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c2bc4339d7..d3040f1093 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -228,9 +228,9 @@ class OpsTraceManager: if not trace_config_data: return None - # decrypt_token - app = db.session.query(App).where(App.id == app_id).first() + stmt = select(App).where(App.id == app_id) + app = db.session.scalar(stmt) if not app: raise ValueError("App not found") @@ -297,20 +297,19 @@ class OpsTraceManager: @classmethod def get_app_config_through_message_id(cls, message_id: str): app_model_config = None - message_data = db.session.query(Message).where(Message.id == message_id).first() + message_stmt = select(Message).where(Message.id == message_id) + message_data = db.session.scalar(message_stmt) if not message_data: return None conversation_id = message_data.conversation_id - conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() + conversation_stmt = select(Conversation).where(Conversation.id == conversation_id) + conversation_data = db.session.scalar(conversation_stmt) if not conversation_data: return None if conversation_data.app_model_config_id: - app_model_config = ( - db.session.query(AppModelConfig) - .where(AppModelConfig.id == conversation_data.app_model_config_id) - .first() - ) + config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id) + app_model_config = db.session.scalar(config_stmt) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 74972a2a9c..549b6a8889 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,8 @@ from collections.abc import Generator, Mapping from typing import Optional, Union +from sqlalchemy import select + from controllers.service_api.wraps import create_or_update_end_user_for_user_id from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ get the user by user id """ - - user = db.session.query(EndUser).where(EndUser.id == user_id).first() + stmt = 
select(EndUser).where(EndUser.id == user_id) + user = db.session.scalar(stmt) if not user: - user = db.session.query(Account).where(Account.id == user_id).first() + stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(stmt) if not user: raise ValueError("user not found") diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 2f4e651461..cdc6ccc821 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -87,7 +87,6 @@ class PromptMessageUtil: if isinstance(prompt_message.content, list): for content in prompt_message.content: if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) text += content.data else: content = cast(ImagePromptMessageContent, content) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index d8cc2293a2..e0fb0591e8 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -2,7 +2,7 @@ import contextlib import json from collections import defaultdict from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, Optional from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -154,8 +154,8 @@ class ProviderManager: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, data=provider_entity, name_func=lambda x: x.provider, ): @@ -276,15 +276,11 @@ class ProviderManager: :param model_type: model type :return: """ - # Get the corresponding TenantDefaultModel record - default_model = ( - db.session.query(TenantDefaultModel) - .where( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(TenantDefaultModel).where( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) + default_model = db.session.scalar(stmt) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -367,16 +363,11 @@ class ProviderManager: model_names = [model.model for model in available_models] if model not in model_names: raise ValueError(f"Model {model} does not exist.") - - # Get the list of available models from get_configurations and check if it is LLM - default_model = ( - db.session.query(TenantDefaultModel) - .where( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type(), - ) - .first() + stmt = select(TenantDefaultModel).where( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), ) + default_model = db.session.scalar(stmt) # create or update TenantDefaultModel record if default_model: @@ -598,16 +589,13 @@ class ProviderManager: provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: db.session.rollback() - existed_provider_record = ( - db.session.query(Provider) - .where( - Provider.tenant_id == tenant_id, - Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - 
Provider.quota_type == ProviderQuotaType.TRIAL.value, - ) - .first() + stmt = select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == ModelProviderID(provider_name).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, ) + existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: continue diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index c98306ea4b..5fb6f9fcc8 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -3,6 +3,7 @@ from typing import Any, Optional import orjson from pydantic import BaseModel +from sqlalchemy import select from configs import dify_config from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -211,11 +212,10 @@ class Jieba(BaseKeyword): return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) - .first() + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id ) + document_segment = db.session.scalar(stmt) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index e872a4e375..fefd42f84d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app +from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config @@ -24,7 +25,7 @@ default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, + "top_k": 4, "score_threshold_enabled": False, } @@ -127,7 +128,8 @@ class RetrievalService: external_retrieval_model: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None, ): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = db.session.scalar(stmt) if not dataset: return [] metadata_condition = ( @@ -316,10 +318,8 @@ class RetrievalService: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: # Handle parent-child documents child_index_node_id = document.metadata.get("doc_id") - - child_chunk = ( - db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() - ) + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = db.session.scalar(child_chunk_stmt) if not child_chunk: continue @@ -378,17 +378,13 @@ class RetrievalService: index_node_id = document.metadata.get("doc_id") if not index_node_id: continue - - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - .first() + document_segment_stmt = 
select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, ) + segment = db.session.scalar(document_segment_stmt) if not segment: continue diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 6f3e15d166..aa0204ba70 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI: response = self._client.query_collection_data(request) documents = [] for match in response.body.matches.match: - if match.score > score_threshold: + if match.score >= score_threshold: metadata = json.loads(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( @@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI: response = self._client.query_collection_data(request) documents = [] for match in response.body.matches.match: - if match.score > score_threshold: + if match.score >= score_threshold: metadata = json.loads(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index df2173d1ca..71472cce41 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -229,7 +229,7 @@ class AnalyticdbVectorBySql: documents = [] for record in cur: id, vector, score, page_content, metadata = record - if score > score_threshold: + if score >= score_threshold: metadata["score"] = score doc = Document( page_content=page_content, diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index d63ca9f695..d30cf42601 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -157,7 +157,7 @@ class BaiduVector(BaseVector): if meta is not None: meta = json.loads(meta) score = row.get("score", 0.0) - if score > score_threshold: + if score >= score_threshold: meta["score"] = score doc = Document(page_content=row_data.get(self.field_text), metadata=meta) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 699a602365..88da86cf76 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -120,7 +120,7 @@ class ChromaVector(BaseVector): distance = distances[index] metadata = dict(metadatas[index]) score = 1 - distance - if score > score_threshold: + if score >= score_threshold: metadata["score"] = score doc = Document( page_content=documents[index], diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index bd986393d1..d22a7e4fd4 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 2) + top_k = kwargs.get("top_k", 4) try: CBrequest = 
search.SearchRequest.create(search.QueryStringQuery("text:" + query)) search_iter = self._scope.search( diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 49c4b392fe..cbad0e67de 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector): docs = [] for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) - if score > score_threshold: + if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score docs.append(doc) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 0a4067e39c..f0d014b1ec 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector): docs = [] for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) - if score > score_threshold: + if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score docs.append(doc) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 3c65a41f08..cba10b5aa5 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector): docs = [] for doc, score in docs_and_scores: score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 - if score > score_threshold: + if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score docs.append(doc) diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py index 3ba9569d3f..c448210d94 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py @@ -194,7 +194,7 @@ class OpenGauss(BaseVector): metadata, text, distance = record score = 1 - distance metadata["score"] = score - if score > score_threshold: + if score >= score_threshold: docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 3c9302f4da..71d2ba1427 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector): metadata["score"] = hit["_score"] score_threshold = float(kwargs.get("score_threshold") or 0.0) - if hit["_score"] > score_threshold: + if hit["_score"] >= score_threshold: doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 0956914070..1b99f649bf 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -261,7 +261,7 @@ class OracleVector(BaseVector): metadata, text, distance = record score = 1 - distance metadata["score"] = score - if score > score_threshold: + if score >= score_threshold: 
docs.append(Document(page_content=text, metadata=metadata)) conn.close() return docs diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index e77befcdae..99cd4a22cb 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -202,7 +202,7 @@ class PGVectoRS(BaseVector): score = 1 - dis metadata["score"] = score score_threshold = float(kwargs.get("score_threshold") or 0.0) - if score > score_threshold: + if score >= score_threshold: doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 108167d749..13be18f920 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -195,7 +195,7 @@ class PGVector(BaseVector): metadata, text, distance = record score = 1 - distance metadata["score"] = score - if score > score_threshold: + if score >= score_threshold: docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py index 580da7d62e..c33e344bff 100644 --- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py +++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py @@ -170,7 +170,7 @@ class VastbaseVector(BaseVector): metadata, text, distance = record score = 1 - distance metadata["score"] = score - if score > score_threshold: + if score >= score_threshold: docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index fcf3a6d126..e55c06e665 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union import qdrant_client from flask import current_app @@ -18,6 +18,7 @@ from qdrant_client.http.models import ( TokenizerType, ) from qdrant_client.local.qdrant_local import QdrantLocal +from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -369,7 +370,7 @@ class QdrantVector(BaseVector): continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - if result.score > score_threshold: + if result.score >= score_threshold: metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value, ""), @@ -426,7 +427,6 @@ class QdrantVector(BaseVector): def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): - self._client = cast(QdrantLocal, self._client) self._client._load() @classmethod @@ -446,11 +446,8 @@ class QdrantVector(BaseVector): class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == dataset.collection_binding_id) - .one_or_none() - ) + stmt = 
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id) + dataset_collection_binding = db.session.scalars(stmt).one_or_none() if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 7a42dd1a89..a200bacfb6 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -233,7 +233,7 @@ class RelytVector(BaseVector): docs = [] for document, score in results: score_threshold = float(kwargs.get("score_threshold") or 0.0) - if 1 - score > score_threshold: + if 1 - score >= score_threshold: docs.append(document) return docs diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index e66959045f..1c154ef360 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -300,7 +300,7 @@ class TableStoreVector(BaseVector): ) documents = [] for search_hit in search_response.search_hits: - if search_hit.score > score_threshold: + if search_hit.score >= score_threshold: ots_column_map = {} for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 0517d5a6d1..3df35d081f 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -39,6 +39,9 @@ class TencentConfig(BaseModel): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} +bm25 = BM25Encoder.default("zh") + + class TencentVector(BaseVector): field_id: str = "id" field_vector: str = "vector" @@ -53,7 +56,6 @@ class TencentVector(BaseVector): self._dimension = 1024 self._init_database() self._load_collection() - self._bm25 = BM25Encoder.default("zh") def _load_collection(self): """ @@ -186,7 +188,7 @@ class TencentVector(BaseVector): metadata=metadata, ) if self._enable_hybrid_search: - doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i]) + doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i]) docs.append(doc) self._client.upsert( database_name=self._client_config.database, @@ -264,7 +266,7 @@ class TencentVector(BaseVector): match=[ KeywordSearch( field_name="sparse_vector", - data=self._bm25.encode_queries(query), + data=bm25.encode_queries(query), ), ], rerank=WeightedRerank( @@ -291,7 +293,7 @@ class TencentVector(BaseVector): score = 1 - result.get("score", 0.0) else: score = result.get("score", 0.0) - if score > score_threshold: + if score >= score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index a76b5d579c..be24f5a561 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -20,6 +20,7 @@ from qdrant_client.http.models import ( ) from qdrant_client.local.qdrant_local import QdrantLocal from requests.auth import HTTPDigestAuth +from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field @@ 
-351,7 +352,7 @@ class TidbOnQdrantVector(BaseVector): metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 - if result.score > score_threshold: + if result.score >= score_threshold: metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value, ""), @@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: - tidb_auth_binding = ( - db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() - ) + stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) + tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): - tidb_auth_binding = ( - db.session.query(TidbAuthBinding) - .where(TidbAuthBinding.tenant_id == dataset.tenant_id) - .one_or_none() - ) + stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) + tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py index e4f15be2b0..9e99f14dc5 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -110,7 +110,7 @@ class UpstashVector(BaseVector): score = record.score if metadata is not None and text is not None: metadata["score"] = score - if score > score_threshold: + if score >= score_threshold: docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index eef03ce412..661a8f37aa 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -3,6 +3,8 @@ import time from abc import ABC, abstractmethod from typing import Any, Optional +from sqlalchemy import select + from configs import dify_config from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -45,11 +47,10 @@ class Vector: vector_type = self._dataset.index_struct_dict["type"] else: if dify_config.VECTOR_STORE_WHITELIST_ENABLE: - whitelist = ( - db.session.query(Whitelist) - .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") - .one_or_none() + stmt = select(Whitelist).where( + Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db" ) + whitelist = db.session.scalars(stmt).one_or_none() if whitelist: vector_type = VectorType.TIDB_ON_QDRANT diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 9166d35bc8..a0a2e47d19 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -192,7 +192,7 @@ class VikingDBVector(BaseVector): metadata = result.fields.get(vdb_Field.METADATA_KEY.value) if metadata is not None: metadata = json.loads(metadata) - if result.score > score_threshold: + if result.score >= score_threshold: metadata["score"] = result.score doc 
= Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 5525ef1685..a7e0789a92 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -220,7 +220,7 @@ class WeaviateVector(BaseVector): for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold - if score > score_threshold: + if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score docs.append(doc) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index f8da3657fc..717cfe8f53 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any, Optional -from sqlalchemy import func +from sqlalchemy import func, select from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -41,9 +41,8 @@ class DatasetDocumentStore: @property def docs(self) -> dict[str, Document]: - document_segments = ( - db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() - ) + stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id) + document_segments = db.session.scalars(stmt).all() output = {} for document_segment in document_segments: @@ -228,10 +227,9 @@ class DatasetDocumentStore: return data def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: - document_segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) - .first() + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id ) + document_segment = db.session.scalar(stmt) return document_segment diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index c97765b1dc..3845392c8d 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -2,7 +2,7 @@ import re from pathlib import Path -from typing import Optional, cast +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.helpers import detect_file_encodings @@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor): markdown_tups.append((current_header, current_text)) markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) + (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 17f4d1af2d..206b2bb921 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,6 +4,7 @@ import operator from typing import Any, Optional, cast import requests +from sqlalchemy import select from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -367,22 +368,17 @@ class NotionExtractor(BaseExtractor): @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: - data_source_binding = ( - 
db.session.query(DataSourceOauthBinding) - .where( - db.and_( - DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', - ) - ) - .first() + stmt = select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) + data_source_binding = db.session.scalar(stmt) if not data_source_binding: raise Exception( f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}" ) - return cast(str, data_source_binding.access_token) + return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 7dfe2e357c..3c43f34104 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -2,7 +2,7 @@ import contextlib from collections.abc import Iterator -from typing import Optional, cast +from typing import Optional from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor): plaintext_file_exists = False if self._file_cache_key: with contextlib.suppress(FileNotFoundError): - text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] documents = list(self.load()) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 9b90bd2bb3..997b0b953b 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): for result in results: metadata = result.metadata metadata["score"] = result.score - if result.score > score_threshold: + if result.score >= score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 52756fbacd..cb7f6ab57a 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -130,13 +130,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if delete_child_chunks: db.session.query(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) - ).delete() + ).delete(synchronize_session=False) db.session.commit() else: vector.delete() if delete_child_chunks: - db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete() + # Use existing compound index: (tenant_id, dataset_id, ...) 
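# --- Illustrative sketch, not part of the diff ---
# The vector stores and index processors above switch the threshold filter from
# a strict `>` to `>=`, so hits whose score lands exactly on the configured
# threshold are kept instead of silently dropped. The scores below are
# hypothetical values used only to show the boundary behaviour.
def filter_hits(scored_hits: list[tuple[str, float]], score_threshold: float = 0.0) -> list[str]:
    # Mirror of the pattern used in the retrieval code: keep a hit when its
    # score is greater than or equal to the threshold.
    return [text for text, score in scored_hits if score >= score_threshold]

hits = [("exact match", 0.80), ("close match", 0.79), ("weak match", 0.20)]
assert filter_hits(hits, score_threshold=0.80) == ["exact match"]      # kept with >=, dropped with >
assert filter_hits(hits, score_threshold=0.0) == [h[0] for h in hits]  # default threshold keeps everything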
+ db.session.query(ChildChunk).where( + ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id + ).delete(synchronize_session=False) db.session.commit() def retrieve( @@ -162,7 +165,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for result in results: metadata = result.metadata metadata["score"] = result.score - if result.score > score_threshold: + if result.score >= score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 609a8aafa1..8c345b7edf 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor): for result in results: metadata = result.metadata metadata["score"] = result.score - if result.score > score_threshold: + if result.score >= score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) return docs diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index cd4af72832..11010c9d60 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping from typing import Any, Optional, Union, cast from flask import Flask, current_app -from sqlalchemy import Float, and_, or_, text +from sqlalchemy import Float, and_, or_, select, text from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy.orm import Session @@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, + "top_k": 4, "score_threshold_enabled": False, } @@ -135,7 +135,8 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) # pass if dataset is not available if not dataset: @@ -240,15 +241,12 @@ class DatasetRetrieval: for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - db.session.query(DatasetDocument) - .where( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ) + document = db.session.scalar(dataset_document_stmt) if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, @@ -327,7 +325,8 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) if dataset: results = [] if dataset.provider == "external": @@ -514,22 +513,18 @@ class DatasetRetrieval: dify_documents = [document for document in documents if 
document.provider == "dify"] for document in dify_documents: if document.metadata is not None: - dataset_document = ( - db.session.query(DatasetDocument) - .where(DatasetDocument.id == document.metadata["document_id"]) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == document.metadata["document_id"] ) + dataset_document = db.session.scalar(dataset_document_stmt) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ( - db.session.query(ChildChunk) - .where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - .first() + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, ) + child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: segment = ( db.session.query(DocumentSegment) @@ -600,7 +595,8 @@ class DatasetRetrieval: ): with flask_app.app_context(): with Session(db.engine) as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) + dataset = session.scalar(dataset_stmt) if not dataset: return [] @@ -647,7 +643,7 @@ class DatasetRetrieval: retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=retrieval_model.get("top_k") or 2, + top_k=retrieval_model.get("top_k") or 4, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, @@ -685,7 +681,8 @@ class DatasetRetrieval: available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(dataset_stmt) # pass if dataset is not available if not dataset: @@ -743,7 +740,7 @@ class DatasetRetrieval: tool = DatasetMultiRetrieverTool.from_dataset( dataset_ids=[dataset.id for dataset in available_datasets], tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, + top_k=retrieve_config.top_k or 4, score_threshold=retrieve_config.score_threshold, hit_callbacks=[hit_callback], return_resource=return_resource, @@ -958,7 +955,8 @@ class DatasetRetrieval: self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> Optional[list[dict[str, Any]]]: # get all metadata field - metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) + metadata_fields = db.session.scalars(metadata_stmt).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] # get metadata model config if metadata_model_config is None: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index cdfefbadb3..90c09a4441 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider
import ApiToolProviderController @@ -54,17 +56,13 @@ class ToolLabelManager: return controller.tool_labels else: raise ValueError("Unsupported tool type") - - labels = ( - db.session.query(ToolLabelBinding.label_name) - .where( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ) - .all() + stmt = select(ToolLabelBinding.label_name).where( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, ) + labels = db.session.scalars(stmt).all() - return [label.label_name for label in labels] + return list(labels) @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 474f8e3bcc..0069c6d6ee 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter +from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -198,14 +199,11 @@ class ToolManager: # get specific credentials if is_valid_uuid(credential_id): try: - builtin_provider = ( - db.session.query(BuiltinToolProvider) - .where( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - .first() + builtin_provider_stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, ) + builtin_provider = db.session.scalar(builtin_provider_stmt) except Exception as e: builtin_provider = None logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) @@ -319,11 +317,10 @@ class ToolManager: ), ) elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + workflow_provider_stmt = select(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id ) + workflow_provider = db.session.scalar(workflow_provider_stmt) if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") @@ -333,16 +330,13 @@ class ToolManager: if controller_tools is None or len(controller_tools) == 0: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return cast( - WorkflowTool, - controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") @@ -652,8 +646,8 @@ class ToolManager: for provider in builtin_providers: # handle include, exclude if is_filtered( - include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider, name_func=lambda x: x.identity.name, ): diff --git 
a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 7eb4bc017a..75c0c6738e 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -3,6 +3,7 @@ from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field +from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager @@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ) - .all() + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), ) + segments = db.session.scalars(document_segment_stmt).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} @@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): resource_number = 1 for segment in sorted_segments: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - db.session.query(Document) - .where( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ) - .first() + document_stmt = select(Document).where( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ) + document = db.session.scalar(document_stmt) if dataset and document: source = RetrievalSourceMetadata( position=resource_number, @@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): hit_callbacks: list[DatasetIndexToolCallbackHandler], ): with flask_app.app_context(): - dataset = ( - db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() - ) + stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id) + dataset = db.session.scalar(stmt) if not dataset: return [] @@ -181,7 +175,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): retrieval_method="keyword_search", dataset_id=dataset.id, query=query, - top_k=retrieval_model.get("top_k") or 2, + top_k=retrieval_model.get("top_k") or 4, ) if documents: all_documents.extend(documents) @@ -192,7 +186,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=retrieval_model.get("top_k") or 2, + top_k=retrieval_model.get("top_k") or 4, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py index 567275531e..4f489e00f4 100644 --- 
a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py @@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): name: str = "dataset" description: str = "use this to retrieve a dataset. " tenant_id: str - top_k: int = 2 + top_k: int = 4 score_threshold: Optional[float] = None hit_callbacks: list[DatasetIndexToolCallbackHandler] = [] return_resource: bool diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f7689d7707..b536c5a25c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,6 +1,7 @@ from typing import Any, Optional, cast from pydantic import BaseModel, Field +from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService @@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): ) def _run(self, query: str) -> str: - dataset = ( - db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() - ) + dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id) + dataset = db.session.scalar(dataset_stmt) if not dataset: return "" @@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() - document = ( - db.session.query(DatasetDocument) # type: ignore - .where( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - .first() + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, ) + document = db.session.scalar(dataset_document_stmt) # type: ignore if dataset and document: source = RetrievalSourceMetadata( dataset_id=dataset.id, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 8357dac0d7..bf075bd730 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -3,7 +3,7 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension -from typing import Optional, cast +from typing import Optional from uuid import UUID import numpy as np @@ -159,8 +159,7 @@ class ToolFileMessageTransformer: elif message.type == ToolInvokeMessage.MessageType.JSON: if isinstance(message.message, ToolInvokeMessage.JsonMessage): - json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) - json_msg.json_object = safe_json_value(json_msg.json_object) + message.message.json_object = safe_json_value(message.message.json_object) yield message else: yield message diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 3f59b3f472..251d914800 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -129,17 +129,14 @@ class ModelInvocationUtils: db.session.commit() try: - response: LLMResult = cast( - LLMResult, - model_instance.invoke_llm( - prompt_messages=prompt_messages, - 
model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], - ), + response: LLMResult = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], ) except InvokeRateLimitError as e: raise InvokeModelError(f"Invoke rate limit error: {e}") diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index d8749f9851..c2092853ea 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,9 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional, cast +from typing import Any, Optional + +from sqlalchemy import select from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool @@ -133,7 +135,8 @@ class WorkflowTool(Tool): .first() ) else: - workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() + stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) + workflow = db.session.scalar(stmt) if not workflow: raise ValueError("workflow not found or not published") @@ -144,7 +147,8 @@ class WorkflowTool(Tool): """ get the app by app id """ - app = db.session.query(App).where(App.id == app_id).first() + stmt = select(App).where(App.id == app_id) + app = db.session.scalar(stmt) if not app: raise ValueError("app not found") @@ -201,14 +205,14 @@ class WorkflowTool(Tool): item = self._update_file_mapping(item) file = build_from_mapping( mapping=item, - tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), + tenant_id=str(self.runtime.tenant_id), ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: value = self._update_file_mapping(value) file = build_from_mapping( mapping=value, - tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id), + tenant_id=str(self.runtime.tenant_id), ) files.append(file) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 16c8116ac1..a994730cd5 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Annotated, TypeAlias, cast +from typing import Annotated, TypeAlias from uuid import uuid4 from pydantic import Discriminator, Field, Tag @@ -86,7 +86,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return cast(str, encrypter.obfuscated_token(self.value)) + return encrypter.obfuscated_token(self.value) class NoneVariable(NoneSegment, Variable): diff --git a/api/core/workflow/graph_engine/graph_engine.py.orig b/api/core/workflow/graph_engine/graph_engine.py.orig new file mode 100644 index 0000000000..833cee0ffe --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py.orig @@ -0,0 +1,339 @@ +""" +QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. + +This engine uses a modular architecture with separated packages following +Domain-Driven Design principles for improved maintainability and testability. 
+""" + +import contextvars +import logging +import queue +from collections.abc import Generator, Mapping +from typing import final + +from flask import Flask, current_app + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphRuntimeState +from core.workflow.enums import NodeExecutionType +from core.workflow.graph import Graph +from core.workflow.graph_events import ( + GraphEngineEvent, + GraphNodeEventBase, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from models.enums import UserFrom + +from .command_processing import AbortCommandHandler, CommandProcessor +from .domain import ExecutionContext, GraphExecution +from .entities.commands import AbortCommand +from .error_handling import ErrorHandler +from .event_management import EventHandler, EventManager +from .graph_traversal import EdgeProcessor, SkipPropagator +from .layers.base import Layer +from .orchestration import Dispatcher, ExecutionCoordinator +from .protocols.command_channel import CommandChannel +from .response_coordinator import ResponseStreamCoordinator +from .state_management import UnifiedStateManager +from .worker_management import SimpleWorkerPool + +logger = logging.getLogger(__name__) + + +@final +class GraphEngine: + """ + Queue-based graph execution engine. + + Uses a modular architecture that delegates responsibilities to specialized + subsystems, following Domain-Driven Design and SOLID principles. + """ + + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, object], + graph_runtime_state: GraphRuntimeState, + max_execution_steps: int, + max_execution_time: int, + command_channel: CommandChannel, + min_workers: int | None = None, + max_workers: int | None = None, + scale_up_threshold: int | None = None, + scale_down_idle_time: float | None = None, + ) -> None: + """Initialize the graph engine with all subsystems and dependencies.""" + + # === Domain Models === + # Execution context encapsulates workflow execution metadata + self._execution_context = ExecutionContext( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + max_execution_steps=max_execution_steps, + max_execution_time=max_execution_time, + ) + + # Graph execution tracks the overall execution state + self._graph_execution = GraphExecution(workflow_id=workflow_id) + + # === Core Dependencies === + # Graph structure and configuration + self._graph = graph + self._graph_config = graph_config + self._graph_runtime_state = graph_runtime_state + self._command_channel = command_channel + + # === Worker Management Parameters === + # Parameters for dynamic worker pool scaling + self._min_workers = min_workers + self._max_workers = max_workers + self._scale_up_threshold = scale_up_threshold + self._scale_down_idle_time = scale_down_idle_time + + # === Execution Queues === + # Queue for nodes ready to execute + self._ready_queue: queue.Queue[str] = queue.Queue() + # Queue for events generated during execution + self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() + + # === State Management === + # Unified state manager handles all node state transitions and queue operations + self._state_manager = UnifiedStateManager(self._graph, self._ready_queue) + + # === Response Coordination === + # 
Coordinates response streaming from response nodes + self._response_coordinator = ResponseStreamCoordinator( + variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph + ) + + # === Event Management === + # Event manager handles both collection and emission of events + self._event_manager = EventManager() + + # === Error Handling === + # Centralized error handler for graph execution errors + self._error_handler = ErrorHandler(self._graph, self._graph_execution) + + # === Graph Traversal Components === + # Propagates skip status through the graph when conditions aren't met + self._skip_propagator = SkipPropagator( + graph=self._graph, + state_manager=self._state_manager, + ) + + # Processes edges to determine next nodes after execution + # Also handles conditional branching and route selection + self._edge_processor = EdgeProcessor( + graph=self._graph, + state_manager=self._state_manager, + response_coordinator=self._response_coordinator, + skip_propagator=self._skip_propagator, + ) + + # === Event Handler Registry === + # Central registry for handling all node execution events + self._event_handler_registry = EventHandler( + graph=self._graph, + graph_runtime_state=self._graph_runtime_state, + graph_execution=self._graph_execution, + response_coordinator=self._response_coordinator, + event_collector=self._event_manager, + edge_processor=self._edge_processor, + state_manager=self._state_manager, + error_handler=self._error_handler, + ) + + # === Command Processing === + # Processes external commands (e.g., abort requests) + self._command_processor = CommandProcessor( + command_channel=self._command_channel, + graph_execution=self._graph_execution, + ) + + # Register abort command handler + abort_handler = AbortCommandHandler() + self._command_processor.register_handler( + AbortCommand, + abort_handler, + ) + + # === Worker Pool Setup === + # Capture Flask app context for worker threads + flask_app: Flask | None = None + try: + app = current_app._get_current_object() # type: ignore + if isinstance(app, Flask): + flask_app = app + except RuntimeError: + pass + + # Capture context variables for worker threads + context_vars = contextvars.copy_context() + + # Create worker pool for parallel node execution + self._worker_pool = SimpleWorkerPool( + ready_queue=self._ready_queue, + event_queue=self._event_queue, + graph=self._graph, + flask_app=flask_app, + context_vars=context_vars, + min_workers=self._min_workers, + max_workers=self._max_workers, + scale_up_threshold=self._scale_up_threshold, + scale_down_idle_time=self._scale_down_idle_time, + ) + + # === Orchestration === + # Coordinates the overall execution lifecycle + self._execution_coordinator = ExecutionCoordinator( + graph_execution=self._graph_execution, + state_manager=self._state_manager, + event_handler=self._event_handler_registry, + event_collector=self._event_manager, + command_processor=self._command_processor, + worker_pool=self._worker_pool, + ) + + # Dispatches events and manages execution flow + self._dispatcher = Dispatcher( + event_queue=self._event_queue, + event_handler=self._event_handler_registry, + event_collector=self._event_manager, + execution_coordinator=self._execution_coordinator, + max_execution_time=self._execution_context.max_execution_time, + event_emitter=self._event_manager, + ) + + # === Extensibility === + # Layers allow plugins to extend engine functionality + self._layers: list[Layer] = [] + + # === Validation === + # Ensure all nodes share the same GraphRuntimeState instance + 
self._validate_graph_state_consistency() + + def _validate_graph_state_consistency(self) -> None: + """Validate that all nodes share the same GraphRuntimeState.""" + expected_state_id = id(self._graph_runtime_state) + for node in self._graph.nodes.values(): + if id(node.graph_runtime_state) != expected_state_id: + raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") + + def layer(self, layer: Layer) -> "GraphEngine": + """Add a layer for extending functionality.""" + self._layers.append(layer) + return self + + def run(self) -> Generator[GraphEngineEvent, None, None]: + """ + Execute the graph using the modular architecture. + + Returns: + Generator yielding GraphEngineEvent instances + """ + try: + # Initialize layers + self._initialize_layers() + + # Start execution + self._graph_execution.start() + start_event = GraphRunStartedEvent() + yield start_event + + # Start subsystems + self._start_execution() + + # Yield events as they occur + yield from self._event_manager.emit_events() + + # Handle completion + if self._graph_execution.aborted: + abort_reason = "Workflow execution aborted by user command" + if self._graph_execution.error: + abort_reason = str(self._graph_execution.error) + yield GraphRunAbortedEvent( + reason=abort_reason, + outputs=self._graph_runtime_state.outputs, + ) + elif self._graph_execution.has_error: + if self._graph_execution.error: + raise self._graph_execution.error + else: + yield GraphRunSucceededEvent( + outputs=self._graph_runtime_state.outputs, + ) + + except Exception as e: + yield GraphRunFailedEvent(error=str(e)) + raise + + finally: + self._stop_execution() + + def _initialize_layers(self) -> None: + """Initialize layers with context.""" + self._event_manager.set_layers(self._layers) + for layer in self._layers: + try: + layer.initialize(self._graph_runtime_state, self._command_channel) + except Exception as e: + logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e) + + try: + layer.on_graph_start() + except Exception as e: + logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e) + + def _start_execution(self) -> None: + """Start execution subsystems.""" + # Start worker pool (it calculates initial workers internally) + self._worker_pool.start() + + # Register response nodes + for node in self._graph.nodes.values(): + if node.execution_type == NodeExecutionType.RESPONSE: + self._response_coordinator.register(node.id) + + # Enqueue root node + root_node = self._graph.root_node + self._state_manager.enqueue_node(root_node.id) + self._state_manager.start_execution(root_node.id) + + # Start dispatcher + self._dispatcher.start() + + def _stop_execution(self) -> None: + """Stop execution subsystems.""" + self._dispatcher.stop() + self._worker_pool.stop() + # Don't mark complete here as the dispatcher already does it + + # Notify layers + logger = logging.getLogger(__name__) + + for layer in self._layers: + try: + layer.on_graph_end(self._graph_execution.error) + except Exception as e: + logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e) + + # Public property accessors for attributes that need external access + @property + def graph_runtime_state(self) -> GraphRuntimeState: + """Get the graph runtime state.""" + return self._graph_runtime_state diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 57b58ab8f5..fa912d5035 100644 --- 
a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -157,7 +157,7 @@ class AgentNode(Node): messages=message_stream, tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, + "agent_strategy": self._node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, user_id=self.user_id, @@ -401,8 +401,7 @@ class AgentNode(Node): current_plugin = next( plugin for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self._node_data).agent_strategy_provider_name + if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 65dde51191..ff241b9cf0 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -301,12 +301,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str: encoding = "utf-8" yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: # If decoding fails, try with utf-8 as last resort try: yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) except (UnicodeDecodeError, yaml.YAMLError): raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d7ccb338f5..bd5cab1e72 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast -from sqlalchemy import Float, and_, func, or_, text +from sqlalchemy import Float, and_, func, or_, select, text from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy.orm import sessionmaker @@ -75,7 +75,7 @@ default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, + "top_k": 4, "score_threshold_enabled": False, } @@ -358,15 +358,12 @@ class KnowledgeRetrievalNode(Node): for record in records: segment = record.segment dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore - document = ( - db.session.query(Document) - .where( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ) - .first() + stmt = select(Document).where( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ) + document = db.session.scalar(stmt) if dataset and document: source = { "metadata": { @@ -505,7 +502,8 @@ class KnowledgeRetrievalNode(Node): self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> list[dict[str, Any]]: 
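# --- Illustrative sketch, not part of the diff ---
# The hunks around here repeat one migration: legacy `db.session.query(...)`
# chains ending in `.first()` / `.all()` become 2.0-style `select()` statements
# executed with `Session.scalar()` (one row or None) or `Session.scalars().all()`
# (all rows). The model, engine, and data below are hypothetical and exist only
# to show the equivalence; they are not Dify's models.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ExampleBinding(Base):
    __tablename__ = "example_binding"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(ExampleBinding(tenant_id="t1"))
    session.commit()

    # Legacy Query API: first matching row or None.
    legacy = session.query(ExampleBinding).filter_by(tenant_id="t1").first()

    # 2.0 style used throughout this PR: build the statement, then execute it.
    stmt = select(ExampleBinding).where(ExampleBinding.tenant_id == "t1")
    modern = session.scalar(stmt)            # one row or None, like .first()
    all_rows = session.scalars(stmt).all()   # list of ORM objects, like .all()

    assert legacy is modern and len(all_rows) == 1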
# get all metadata field - metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() + stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) + metadata_fields = db.session.scalars(stmt).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] if node_data.metadata_model_config is None: raise ValueError("metadata_model_config is required") diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 3e4882fd1e..648ea69936 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -138,7 +138,7 @@ class ParameterExtractorNode(Node): """ Run the node. """ - node_data = cast(ParameterExtractorNodeData, self._node_data) + node_data = self._node_data variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 968332959c..afc45cf9cf 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -103,7 +103,7 @@ class QuestionClassifierNode(Node): return "1" def _run(self): - node_data = cast(QuestionClassifierNodeData, self._node_data) + node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool # extract variables diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 39524dcd4f..9708edcb38 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from sqlalchemy import select from sqlalchemy.orm import Session @@ -61,7 +61,7 @@ class ToolNode(Node): """ from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - node_data = cast(ToolNodeData, self._node_data) + node_data = self._node_data # fetch tool icon tool_info = { diff --git a/api/core/workflow/nodes/tool/tool_node.py.orig b/api/core/workflow/nodes/tool/tool_node.py.orig new file mode 100644 index 0000000000..9708edcb38 --- /dev/null +++ b/api/core/workflow/nodes/tool/tool_node.py.orig @@ -0,0 +1,493 @@ +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.file import File, FileTransferMethod +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from 
core.variables.segments import ArrayAnySegment, ArrayFileSegment +from core.variables.variables import ArrayAnyVariable +from core.workflow.enums import ( + ErrorStrategy, + NodeType, + SystemVariableKey, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .entities import ToolNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) + +if TYPE_CHECKING: + from core.workflow.entities import VariablePool + + +class ToolNode(Node): + """ + Tool Node + """ + + node_type = NodeType.TOOL + + _node_data: ToolNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ToolNodeData.model_validate(data) + + @classmethod + def version(cls) -> str: + return "1" + + def _run(self) -> Generator: + """ + Run the tool node + """ + from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError + + node_data = self._node_data + + # fetch tool icon + tool_info = { + "provider_type": node_data.provider_type.value, + "provider_id": node_data.provider_id, + "plugin_unique_identifier": node_data.plugin_unique_identifier, + } + + # get tool runtime + try: + from core.tools.tool_manager import ToolManager + + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. 
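# --- Illustrative sketch, not part of the diff ---
# Several nodes in this PR drop `cast(SomeNodeData, self._node_data)` calls.
# Once the subclass annotates `_node_data` with its concrete data type (as
# ToolNode does with `_node_data: ToolNodeData` above), type checkers already
# know the narrow type and the runtime cast is redundant. The classes below are
# hypothetical stand-ins, not the real node classes.
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel

class ExampleNodeData(BaseModel):
    query: str = ""

class ExampleNode:
    _node_data: ExampleNodeData  # concrete annotation makes cast() unnecessary

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = ExampleNodeData.model_validate(data)

    def run(self) -> str:
        # No `cast(ExampleNodeData, self._node_data)` needed here.
        return self._node_data.query

node = ExampleNode()
node.init_node_data({"query": "hello"})
assert node.run() == "hello"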
+ variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + variable_pool = self.graph_runtime_state.variable_pool + tool_runtime = ToolManager.get_workflow_tool_runtime( + self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool + ) + except ToolNodeError as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to get tool runtime: {str(e)}", + error_type=type(e).__name__, + ) + ) + return + + # get parameters + tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + parameters = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self._node_data, + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self._node_data, + for_log=True, + ) + # get conversation id + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + + try: + message_stream = ToolEngine.generic_invoke( + tool=tool_runtime, + tool_parameters=parameters, + user_id=self.user_id, + workflow_tool_callback=DifyWorkflowCallbackHandler(), + workflow_call_depth=self.workflow_call_depth, + app_id=self.app_id, + conversation_id=conversation_id.text if conversation_id else None, + ) + except ToolNodeError as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool: {str(e)}", + error_type=type(e).__name__, + ) + ) + return + + try: + # convert tool messages + yield from self._transform_message( + messages=message_stream, + tool_info=tool_info, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_id=self._node_id, + ) + except ToolInvokeError as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}", + error_type=type(e).__name__, + ) + ) + except PluginInvokeError as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, + error="An error occurred in the plugin, " + f"please contact the author of {node_data.provider_name} for help, " + f"error type: {e.get_error_type()}, " + f"error details: {e.get_error_message()}", + error_type=type(e).__name__, + ) + ) + except PluginDaemonClientSideError as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, + error=f"Failed to invoke tool, error: {e.description}", + error_type=type(e).__name__, + ) + ) + + def _generate_parameters( + self, + *, + tool_parameters: Sequence[ToolParameter], + variable_pool: "VariablePool", + node_data: ToolNodeData, + for_log: bool = False, + ) -> dict[str, Any]: + """ + Generate parameters based on the given tool parameters, variable 
pool, and node data. + + Args: + tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (ToolNodeData): The data associated with the tool node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. + + """ + tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.tool_parameters: + parameter = tool_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + if parameter.required: + raise ToolParameterError(f"Variable {tool_input.value} does not exist") + continue + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value + + return result + + def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: + variable = variable_pool.get(["sys", SystemVariableKey.FILES.value]) + assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) + return list(variable.value) if variable else [] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_id: str, + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + from core.plugin.impl.plugin import PluginInstaller + + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = 
select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"tool file {tool_file_id} not exists") + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=message.message.text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + # JSON message handling for tool node + if message.message.json_object is not None: + json.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=stream_text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, dict) + # Validate that meta contains a 'file' key + if "file" not in message.meta: + raise ToolNodeError("File message is missing 'file' key in meta") + + # Validate that the file is an instance of File + if not isinstance(message.meta["file"], File): + raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any]] = [] + + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + 
if json: + json_output.extend(json) + else: + json_output.append({"data": []}) + + # Send final chunk events for all streamed outputs + # Final chunk for text stream + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk="", + is_final=True, + ) + + # Final chunks for any streamed variables + for var_name in variables: + yield StreamChunkEvent( + selector=[self._node_id, var_name], + chunk="", + is_final=True, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + metadata={ + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + }, + inputs=parameters_for_log, + ) + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: Mapping[str, Any], + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + # Create typed NodeData from dict + typed_node_data = ToolNodeData.model_validate(node_data) + + result = {} + for parameter_name in typed_node_data.tool_parameters: + input = typed_node_data.tool_parameters[parameter_name] + if input.type == "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + elif input.type == "variable": + result[parameter_name] = input.value + elif input.type == "constant": + pass + + result = {node_id + "." + key: value for key, value in result.items()} + + return result + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 78f7b39a06..6f1c530b14 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -318,7 +318,6 @@ class WorkflowEntry: # init graph graph = Graph.init(graph_config=graph_dict, node_factory=node_factory) - node_cls = cast(type[Node], node_cls) # init workflow run state node_config = { "id": node_id, diff --git a/api/core/workflow/workflow_entry.py.orig b/api/core/workflow/workflow_entry.py.orig new file mode 100644 index 0000000000..69dd1bdebc --- /dev/null +++ b/api/core/workflow/workflow_entry.py.orig @@ -0,0 +1,445 @@ +import logging +import time +import uuid +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional + +from configs import dify_config +from core.app.apps.exc import GenerateTaskStoppedError +from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File +from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph import Graph +from core.workflow.graph_engine import GraphEngine +from core.workflow.graph_engine.command_channels import InMemoryChannel +from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from core.workflow.graph_engine.protocols.command_channel import CommandChannel +from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from core.workflow.nodes import NodeType +from core.workflow.nodes.base.node import Node +from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from factories import file_factory +from models.enums import UserFrom +from models.workflow import Workflow + +logger = logging.getLogger(__name__) + + +class WorkflowEntry: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph_runtime_state: GraphRuntimeState, + command_channel: Optional[CommandChannel] = None, + ) -> None: + """ + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph + :param user_id: user id + :param user_from: user from + :param invoke_from: invoke from + :param call_depth: call depth + :param variable_pool: variable pool + :param graph_runtime_state: pre-created graph runtime state + :param command_channel: command channel for external control (optional, defaults to InMemoryChannel) + :param thread_pool_id: thread pool id + """ + # check call depth + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.") + + # Use provided command channel or default to InMemoryChannel + if command_channel is None: + command_channel = InMemoryChannel() + + self.command_channel = command_channel + self.graph_engine = GraphEngine( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + graph_config=graph_config, + graph_runtime_state=graph_runtime_state, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + command_channel=command_channel, + ) + + # Add debug logging layer when in debug mode + if dify_config.DEBUG: + logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine") + debug_layer = DebugLoggingLayer( + level="DEBUG", + include_inputs=True, + include_outputs=True, + include_process_data=False, # Process data can be very verbose + logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger + ) + self.graph_engine.layer(debug_layer) + + # Add execution limits layer + limits_layer 
= ExecutionLimitsLayer( + max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + ) + self.graph_engine.layer(limits_layer) + + def run(self) -> Generator[GraphEngineEvent, None, None]: + graph_engine = self.graph_engine + + try: + # run workflow + generator = graph_engine.run() + yield from generator + except GenerateTaskStoppedError: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + yield GraphRunFailedEvent(error=str(e)) + return + + @classmethod + def single_step_run( + cls, + *, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: Mapping[str, Any], + variable_pool: VariablePool, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config.get("data", {}) + + # Get node class + node_type = NodeType(node_config_data.get("type")) + node_version = node_config_data.get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + + # init graph init params and runtime state + graph_init_params = GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + # init node factory + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + # init graph + graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory) + + # init workflow run state + node = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node.init_node_data(node_config_data) + + try: + # variable selector to variable mapping + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + # Loading missing variable from draft var here, and set it into + # variable_pool. + load_into_variable_pool( + variable_loader=variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + ) + + try: + # run node + generator = node.run() + except Exception as e: + logger.exception( + "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", + workflow.id, + node.id, + node.node_type, + node.version(), + ) + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + return node, generator + + @staticmethod + def _create_single_node_graph( + node_id: str, + node_data: dict[str, Any], + node_width: int = 114, + node_height: int = 514, + ) -> dict[str, Any]: + """ + Create a minimal graph structure for testing a single node in isolation. 
+ + :param node_id: ID of the target node + :param node_data: configuration data for the target node + :param node_width: width for UI layout (default: 200) + :param node_height: height for UI layout (default: 100) + :return: graph dictionary with start node and target node + """ + node_config = { + "id": node_id, + "width": node_width, + "height": node_height, + "type": "custom", + "data": node_data, + } + start_node_config = { + "id": "start", + "width": node_width, + "height": node_height, + "type": "custom", + "data": { + "type": NodeType.START.value, + "title": "Start", + "desc": "Start", + }, + } + return { + "nodes": [start_node_config, node_config], + "edges": [ + { + "source": "start", + "target": node_id, + "sourceHandle": "source", + "targetHandle": "target", + } + ], + } + + @classmethod + def run_free_node( + cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] + ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: + """ + Run free node + + NOTE: only parameter_extractor/question_classifier are supported + + :param node_data: node data + :param node_id: node id + :param tenant_id: tenant id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # Create a minimal graph for single node execution + graph_dict = cls._create_single_node_graph(node_id, node_data) + + node_type = NodeType(node_data.get("type", "")) + if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: + raise ValueError(f"Node type {node_type} not supported") + + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] + if not node_cls: + raise ValueError(f"Node class not found for node type {node_type}") + + # init variable pool + variable_pool = VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + environment_variables=[], + ) + + # init graph init params and runtime state + graph_init_params = GraphInitParams( + tenant_id=tenant_id, + app_id="", + workflow_id="", + graph_config=graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + # init node factory + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + # init graph + graph = Graph.init(graph_config=graph_dict, node_factory=node_factory) + + node_cls = cast(type[Node], node_cls) + # init workflow run state + node_config = { + "id": node_id, + "data": node_data, + } + node: Node = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node.init_node_data(node_data) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_dict, config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=tenant_id, + ) + + # run node + generator = node.run() + + return node, generator + except Exception as e: + logger.exception( + "error while running node, node_id=%s, node_type=%s, node_version=%s", + node.id, + node.node_type, + node.version(), + ) + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + + @staticmethod + def handle_special_values(value: 
Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + # NOTE(QuantumGhost): Avoid using this function in new code. + # Keep values structured as long as possible and only convert to dict + # immediately before serialization (e.g., JSON serialization) to maintain + # data integrity and type information. + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) + + @staticmethod + def _handle_special_values(value: Any) -> Any: + if value is None: + return value + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = WorkflowEntry._handle_special_values(v) + return res + if isinstance(value, list): + res_list = [] + for item in value: + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list + if isinstance(value, File): + return value.to_dict() + return value + + @classmethod + def mapping_user_inputs_to_variable_pool( + cls, + *, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: Mapping[str, Any], + variable_pool: VariablePool, + tenant_id: str, + ) -> None: + # NOTE(QuantumGhost): This logic should remain synchronized with + # the implementation of `load_into_variable_pool`, specifically the logic about + # variable existence checking. + + # WARNING(QuantumGhost): The semantics of this method are not clearly defined, + # and multiple parts of the codebase depend on its current behavior. + # Modify with caution. + for node_variable, variable_selector in variable_mapping.items(): + # fetch node id and variable key from node_variable + node_variable_list = node_variable.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid node variable {node_variable}") + + node_variable_key = ".".join(node_variable_list[1:]) + + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( + variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") + + # environment variable already exist in variable pool, not from user inputs + if variable_pool.get(variable_selector): + continue + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + variable_key_list = list(variable_key_list) + + # get input value + input_value = user_inputs.get(node_variable) + if not input_value: + input_value = user_inputs.get(node_variable_key) + + if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: + input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + if ( + isinstance(input_value, list) + and all(isinstance(item, dict) for item in input_value) + and all("type" in item and "transfer_method" in item for item in input_value) + ): + input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + + # append variable and value to variable pool + if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: + variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 6680bc692d..35feaae4e0 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Ch from 
core.entities.provider_entities import QuotaUnit, SystemConfiguration from events.message_event import message_was_created from extensions.ext_database import db +from extensions.ext_redis import redis_client, redis_fallback from libs import datetime_utils from models.model import Message from models.provider import Provider, ProviderType @@ -19,6 +20,32 @@ from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) +# Redis cache key prefix for provider last used timestamps +_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used" +# Default TTL for cache entries (10 minutes) +_CACHE_TTL_SECONDS = 600 +LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5 + + +def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str: + """Generate Redis cache key for provider last used timestamp.""" + return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}" + + +@redis_fallback(default_return=None) +def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]: + """Get last update timestamp from Redis cache.""" + timestamp_str = redis_client.get(cache_key) + if timestamp_str: + return datetime.fromtimestamp(float(timestamp_str.decode("utf-8"))) + return None + + +@redis_fallback() +def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None: + """Set last update timestamp in Redis cache with TTL.""" + redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp())) + class _ProviderUpdateFilters(BaseModel): """Filters for identifying Provider records to update.""" @@ -139,7 +166,7 @@ def handle(sender: Message, **kwargs): provider_name, ) - except Exception as e: + except Exception: # Log failure with timing and context duration = time_module.perf_counter() - start_time @@ -215,8 +242,23 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] # Prepare values dict for SQLAlchemy update update_values = {} - # updateing to `last_used` is removed due to performance reason. - # ref: https://github.com/langgenius/dify/issues/24526 + + # NOTE: For frequently used providers under high load, this implementation may experience + # race conditions or update contention despite the time-window optimization: + # 1. Multiple concurrent requests might check the same cache key simultaneously + # 2. Redis cache operations are not atomic with database updates + # 3. Heavy providers could still face database lock contention during peak usage + # The current implementation is acceptable for most scenarios, but future optimization + # considerations could include: batched updates, or async processing. + if values.last_used is not None: + cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name) + now = datetime_utils.naive_utc_now() + last_update = _get_last_update_timestamp(cache_key) + + if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: + update_values["last_used"] = values.last_used + _set_last_update_timestamp(cache_key, now) + if values.quota_used is not None: update_values["quota_used"] = values.quota_used # Skip the current update operation if no updates are required. 
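The NOTE in the hunk above points out that the read-then-compare-then-write cache check is not atomic, so several concurrent requests can all pass the window test and write `last_used` together. For reference only, a minimal sketch (not part of this patch) of how the same throttle could be collapsed into a single atomic Redis call: `SET` with `nx=True` and `ex=...` succeeds for at most one caller per window. The helper name `_should_write_last_used` is hypothetical; `redis_client`, `redis_fallback`, and `LAST_USED_UPDATE_WINDOW_SECONDS` mirror the names introduced in the hunk above, and the sketch assumes `redis_fallback` accepts `default_return=True` the same way it accepts `default_return=None` there.

# Illustrative sketch only, assuming the helpers shown in the hunk above.
from datetime import datetime

from extensions.ext_redis import redis_client, redis_fallback

LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5


@redis_fallback(default_return=True)
def _should_write_last_used(cache_key: str, now: datetime) -> bool:
    # SET with nx=True only succeeds when the key is absent, and ex=... makes the
    # key expire after one window, so at most one caller per window gets True and
    # goes on to update Provider.last_used in the database.
    # If Redis is unreachable, redis_fallback returns True and the write proceeds.
    return bool(
        redis_client.set(
            cache_key,
            str(now.timestamp()),
            nx=True,
            ex=LAST_USED_UPDATE_WINDOW_SECONDS,
        )
    )

If something like this were used, the `values.last_used` branch in `_execute_provider_updates` would call `_should_write_last_used(cache_key, now)` instead of the separate get/compare/set, trading the exact "more than N seconds since the last write" semantics for a fixed-window gate that cannot be passed by two workers at once.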
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 0ea7d3ae1e..62e3bfa3ba 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -3,7 +3,7 @@ import os import urllib.parse import uuid from collections.abc import Callable, Mapping, Sequence -from typing import Any, cast +from typing import Any import httpx from sqlalchemy import select @@ -258,7 +258,6 @@ def _get_remote_file_info(url: str): mime_type = "" resp = ssrf_proxy.head(url, follow_redirects=True) - resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f048d0f3b6..53cb9de3ee 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -17,7 +17,7 @@ class EnvironmentVariableField(fields.Raw): return { "id": value.id, "name": value.name, - "value": encrypter.obfuscated_token(value.value), + "value": encrypter.full_mask_token(), "value_type": value.value_type.value, "description": value.description, } diff --git a/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py b/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py new file mode 100644 index 0000000000..465f8664a5 --- /dev/null +++ b/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py @@ -0,0 +1,32 @@ +"""chore: add workflow app log run id index + +Revision ID: b95962a3885c +Revises: 0e154742a5fa +Create Date: 2025-08-29 15:34:09.838623 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'b95962a3885c' +down_revision = '8d289573e1da' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.create_index('workflow_app_log_workflow_run_id_idx', ['workflow_run_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_app_log_workflow_run_id_idx') + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 26bbc03694..2d15b778e9 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -319,7 +319,7 @@ class MCPToolProvider(Base): @property def decrypted_server_url(self) -> str: - return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url)) + return encrypter.decrypt_token(self.tenant_id, self.server_url) @property def masked_server_url(self) -> str: diff --git a/api/models/workflow.py b/api/models/workflow.py index 842b227028..3c449c89cc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -835,6 +835,7 @@ class WorkflowAppLog(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) diff --git a/api/services/account_service.py b/api/services/account_service.py index 089e667166..50ce171ded 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -146,7 +146,7 @@ class AccountService: account.last_active_at = naive_utc_now() db.session.commit() - return cast(Account, account) + return account @staticmethod def get_account_jwt_token(account: Account) -> str: @@ -191,7 +191,7 @@ class AccountService: db.session.commit() - return cast(Account, account) + return account @staticmethod def update_account_password(account, password, new_password): @@ -1127,7 +1127,7 @@ class TenantService: def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) - return cast(dict, tenant.custom_config_dict) + return tenant.custom_config_dict @staticmethod def is_owner(account: Account, tenant: Tenant) -> bool: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 6603063c22..9ee92bc2dc 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,5 +1,5 @@ import uuid -from typing import cast +from typing import Optional import pandas as pd from flask_login import current_user @@ -40,7 +40,7 @@ class AppAnnotationService: if not message: raise NotFound("Message Not Exists.") - annotation = message.annotation + annotation: Optional[MessageAnnotation] = message.annotation # save the message annotation if annotation: annotation.content = args["answer"] @@ -70,7 +70,7 @@ class AppAnnotationService: app_id, annotation_setting.collection_binding_id, ) - return cast(MessageAnnotation, annotation) + return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0d10aa15dd..bc574fd8bc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1149,7 +1149,7 @@ class DocumentService: "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, + "top_k": 4, "score_threshold_enabled": False, } @@ -1612,7 +1612,7 @@ class DocumentService: search_method=RetrievalMethod.SEMANTIC_SEARCH.value, reranking_enable=False, reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), - top_k=2, + top_k=4, 
score_threshold_enabled=False, ) # save dataset diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 1517ca6594..bce28da032 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -18,7 +18,7 @@ default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, + "top_k": 4, "score_threshold_enabled": False, } @@ -66,7 +66,7 @@ class HitTestingService: retrieval_method=retrieval_model.get("search_method", "semantic_search"), dataset_id=dataset.id, query=query, - top_k=retrieval_model.get("top_k", 2), + top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/services/message_service.py b/api/services/message_service.py index a19d6ee157..13c8e948ca 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -112,7 +112,9 @@ class MessageService: base_query = base_query.where(Message.conversation_id == conversation.id) # Check if include_ids is not None and not empty to avoid WHERE false condition - if include_ids is not None and len(include_ids) > 0: + if include_ids is not None: + if len(include_ids) == 0: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = base_query.where(Message.id.in_(include_ids)) if last_id: diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 88e8697d17..955e898ec1 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -27,73 +27,73 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if dataset is None: - logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) - db.session.close() - return - - # check document limit - features = FeatureService.get_features(dataset.tenant_id) try: - if features.billing.enabled: - vector_space = features.vector_space - count = len(document_ids) - if features.billing.subscription.plan == "sandbox" and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." 
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset is None: + logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + db.session.close() + return + + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + if features.billing.subscription.plan == "sandbox" and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) + except Exception as e: + for document_id in document_ids: + document = ( + db.session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() ) - except Exception as e: + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + db.session.add(document) + db.session.commit() + return + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) + document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = naive_utc_now() + # clean old data + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + db.session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + documents.append(document) db.session.add(document) db.session.commit() - return - finally: - db.session.close() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - if document: - # clean old data - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() - - try: indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py 
b/api/tests/integration_tests/workflow/nodes/test_code.py index e6f3f0ddf6..fe46ed7658 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -1,7 +1,6 @@ import time import uuid from os import getenv -from typing import cast import pytest @@ -11,7 +10,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.system_variable import SystemVariable from models.enums import UserFrom @@ -242,8 +240,6 @@ def test_execute_code_output_validator_depth(): "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, } - node._node_data = cast(CodeNodeData, node._node_data) - # validate node._transform_result(result, node._node_data.outputs) @@ -338,8 +334,6 @@ def test_execute_code_output_object_list(): ] } - node._node_data = cast(CodeNodeData, node._node_data) - # validate node._transform_result(result, node._node_data.outputs) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py.orig b/api/tests/integration_tests/workflow/nodes/test_code.py.orig new file mode 100644 index 0000000000..fe46ed7658 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/test_code.py.orig @@ -0,0 +1,390 @@ +import time +import uuid +from os import getenv + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool +from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.graph import Graph +from core.workflow.node_events import NodeRunResult +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.node_factory import DifyNodeFactory +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock + +CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) + + +def init_code_node(code_config: dict): + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + }, + ], + "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], + } + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="aaa", files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["code", "args1"], 1) + variable_pool.add(["code", "args2"], 2) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + # Create node factory + node_factory = DifyNodeFactory( + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + ) + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + node = CodeNode( + id=str(uuid.uuid4()), + config=code_config, + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + ) + + # Initialize node data + if "data" in code_config: + 
node.init_node_data(code_config["data"]) + + return node + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code(setup_code_executor_mock): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert result.outputs["result"] == 3 + assert result.error == "" + + +@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) +def test_execute_code_output_validator(setup_code_executor_mock): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": args1 + args2, + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "string", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) + + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "Output variable `result` must be a string" + + +def test_execute_code_output_validator_depth(): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "string_validator": { + "type": "string", + }, + "number_validator": { + "type": "number", + }, + "number_array_validator": { + "type": "array[number]", + }, + "string_array_validator": { + "type": "array[string]", + }, + "object_validator": { + "type": "object", + "children": { + "result": { + "type": "number", + }, + "depth": { + "type": "object", + "children": { + "depth": { + "type": "object", + "children": { + "depth": { + "type": "number", + } + }, + } + }, + }, + }, + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333], + 
"string_array_validator": ["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + # validate + node._transform_result(result, node._node_data.outputs) + + # construct result + result = { + "number_validator": "1", + "string_validator": 1, + "number_array_validator": ["1", "2", "3", "3.333"], + "string_array_validator": [1, 2, 3], + "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node._node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1", + "number_array_validator": [1, 2, 3, 3.333], + "string_array_validator": ["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node._node_data.outputs) + + # construct result + result = { + "number_validator": 1, + "string_validator": "1", + "number_array_validator": [1, 2, 3, 3.333] * 2000, + "string_array_validator": ["1", "2", "3"], + "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}}, + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node._node_data.outputs) + + +def test_execute_code_output_object_list(): + code = """ + def main(args1: int, args2: int) -> dict: + return { + "result": { + "result": args1 + args2, + } + } + """ + # trim first 4 spaces at the beginning of each line + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "object_list": { + "type": "array[object]", + }, + }, + "title": "123", + "variables": [ + { + "variable": "args1", + "value_selector": ["1", "args1"], + }, + {"variable": "args2", "value_selector": ["1", "args2"]}, + ], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + + # construct result + result = { + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + ] + } + + # validate + node._transform_result(result, node._node_data.outputs) + + # construct result + result = { + "object_list": [ + { + "result": 1, + }, + { + "result": 2, + }, + { + "result": [1, 2, 3], + }, + 1, + ] + } + + # validate + with pytest.raises(ValueError): + node._transform_result(result, node._node_data.outputs) + + +def test_execute_code_scientific_notation(): + code = """ + def main() -> dict: + return { + "result": -8.0E-5 + } + """ + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "title": "123", + "variables": [], + "answer": "123", + "code_language": "python3", + "code": code, + }, + } + + node = init_code_node(code_config) + # execute node + result = node._run() + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py new file mode 100644 index 0000000000..bf25968100 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -0,0 +1,788 @@ +from unittest.mock import Mock, patch + +import pytest +from faker import Faker + +from 
core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider +from services.tools.tools_transform_service import ToolTransformService + + +class TestToolTransformService: + """Integration tests for ToolTransformService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.tools.tools_transform_service.dify_config") as mock_dify_config, + ): + # Setup default mock returns + mock_dify_config.CONSOLE_API_URL = "https://console.example.com" + + yield { + "dify_config": mock_dify_config, + } + + def _create_test_tool_provider( + self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" + ): + """ + Helper method to create a test tool provider for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + provider_type: Type of provider to create + + Returns: + Tool provider instance + """ + fake = Faker() + + if provider_type == "api": + provider = ApiToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + icon_dark='{"background": "#252525", "content": "🔧"}', + tenant_id="test_tenant_id", + user_id="test_user_id", + credentials={"auth_type": "api_key_header", "api_key": "test_key"}, + provider_type="api", + ) + elif provider_type == "builtin": + provider = BuiltinToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon="🔧", + icon_dark="🔧", + tenant_id="test_tenant_id", + provider="test_provider", + credential_type="api_key", + credentials={"api_key": "test_key"}, + ) + elif provider_type == "workflow": + provider = WorkflowToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + icon_dark='{"background": "#252525", "content": "🔧"}', + tenant_id="test_tenant_id", + user_id="test_user_id", + workflow_id="test_workflow_id", + ) + elif provider_type == "mcp": + provider = MCPToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + provider_icon='{"background": "#FF6B6B", "content": "🔧"}', + tenant_id="test_tenant_id", + user_id="test_user_id", + server_url="https://mcp.example.com", + server_identifier="test_server", + tools='[{"name": "test_tool", "description": "Test tool"}]', + authed=True, + ) + else: + raise ValueError(f"Unknown provider type: {provider_type}") + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + return provider + + def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful plugin icon URL generation. 
+ + This test verifies: + - Proper URL construction for plugin icons + - Correct tenant_id and filename handling + - URL format compliance + """ + # Arrange: Setup test data + fake = Faker() + tenant_id = fake.uuid4() + filename = "test_icon.png" + + # Act: Execute the method under test + result = ToolTransformService.get_plugin_icon_url(tenant_id, filename) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, str) + assert "console/api/workspaces/current/plugin/icon" in result + assert tenant_id in result + assert filename in result + assert result.startswith("https://console.example.com") + + # Verify URL structure + expected_url = f"https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}" + assert result == expected_url + + def test_get_plugin_icon_url_with_empty_console_url( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test plugin icon URL generation when CONSOLE_API_URL is empty. + + This test verifies: + - Fallback to relative URL when CONSOLE_API_URL is None + - Proper URL construction with relative path + """ + # Arrange: Setup mock with empty console URL + mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None + fake = Faker() + tenant_id = fake.uuid4() + filename = "test_icon.png" + + # Act: Execute the method under test + result = ToolTransformService.get_plugin_icon_url(tenant_id, filename) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, str) + assert result.startswith("/console/api/workspaces/current/plugin/icon") + assert tenant_id in result + assert filename in result + + # Verify URL structure + expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}" + assert result == expected_url + + def test_get_tool_provider_icon_url_builtin_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful tool provider icon URL generation for builtin providers. 
+ + This test verifies: + - Proper URL construction for builtin tool providers + - Correct provider type handling + - URL format compliance + """ + # Arrange: Setup test data + fake = Faker() + provider_type = ToolProviderType.BUILT_IN.value + provider_name = fake.company() + icon = "🔧" + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, str) + assert "console/api/workspaces/current/tool-provider/builtin" in result + # Note: provider_name may contain spaces that get URL encoded + assert provider_name.replace(" ", "%20") in result or provider_name in result + assert result.endswith("/icon") + assert result.startswith("https://console.example.com") + + # Verify URL structure (accounting for URL encoding) + # The actual result will have URL-encoded spaces (%20), so we need to compare accordingly + expected_url = ( + f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon" + ) + # Convert expected URL to match the actual URL encoding + expected_encoded = expected_url.replace(" ", "%20") + assert result == expected_encoded + + def test_get_tool_provider_icon_url_api_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful tool provider icon URL generation for API providers. + + This test verifies: + - Proper icon handling for API tool providers + - JSON string parsing for icon data + - Fallback icon when parsing fails + """ + # Arrange: Setup test data + fake = Faker() + provider_type = ToolProviderType.API.value + provider_name = fake.company() + icon = '{"background": "#FF6B6B", "content": "🔧"}' + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, dict) + assert result["background"] == "#FF6B6B" + assert result["content"] == "🔧" + + def test_get_tool_provider_icon_url_api_invalid_json( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tool provider icon URL generation for API providers with invalid JSON. + + This test verifies: + - Proper fallback when JSON parsing fails + - Default icon structure when exception occurs + """ + # Arrange: Setup test data with invalid JSON + fake = Faker() + provider_type = ToolProviderType.API.value + provider_name = fake.company() + icon = '{"invalid": json}' + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, dict) + assert result["background"] == "#252525" + # Note: emoji characters may be represented as Unicode escape sequences + assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" + + def test_get_tool_provider_icon_url_workflow_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful tool provider icon URL generation for workflow providers. 
+ + This test verifies: + - Proper icon handling for workflow tool providers + - Direct icon return for workflow type + """ + # Arrange: Setup test data + fake = Faker() + provider_type = ToolProviderType.WORKFLOW.value + provider_name = fake.company() + icon = {"background": "#FF6B6B", "content": "🔧"} + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, dict) + assert result["background"] == "#FF6B6B" + assert result["content"] == "🔧" + + def test_get_tool_provider_icon_url_mcp_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful tool provider icon URL generation for MCP providers. + + This test verifies: + - Direct icon return for MCP type + - No URL transformation for MCP providers + """ + # Arrange: Setup test data + fake = Faker() + provider_type = ToolProviderType.MCP.value + provider_name = fake.company() + icon = {"background": "#FF6B6B", "content": "🔧"} + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, dict) + assert result["background"] == "#FF6B6B" + assert result["content"] == "🔧" + + def test_get_tool_provider_icon_url_unknown_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tool provider icon URL generation for unknown provider types. + + This test verifies: + - Empty string return for unknown provider types + - Proper handling of unsupported types + """ + # Arrange: Setup test data with unknown type + fake = Faker() + provider_type = "unknown_type" + provider_name = fake.company() + icon = "🔧" + + # Act: Execute the method under test + result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon) + + # Assert: Verify the expected outcomes + assert result == "" + + def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful provider repacking with dictionary input. + + This test verifies: + - Proper icon URL generation for dictionary providers + - Correct provider type handling + - Icon transformation for different provider types + """ + # Arrange: Setup test data + fake = Faker() + tenant_id = fake.uuid4() + provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"} + + # Act: Execute the method under test + ToolTransformService.repack_provider(tenant_id, provider) + + # Assert: Verify the expected outcomes + assert "icon" in provider + assert isinstance(provider["icon"], str) + assert "console/api/workspaces/current/tool-provider/builtin" in provider["icon"] + # Note: provider name may contain spaces that get URL encoded + assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] + + def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful provider repacking with ToolProviderApiEntity input. 
+ + This test verifies: + - Proper icon URL generation for entity providers + - Plugin icon handling when plugin_id is present + - Regular icon handling when plugin_id is not present + """ + # Arrange: Setup test data + fake = Faker() + tenant_id = fake.uuid4() + + # Create provider entity with plugin_id + provider = ToolProviderApiEntity( + id=fake.uuid4(), + author=fake.name(), + name=fake.company(), + description=I18nObject(en_US=fake.text(max_nb_chars=100)), + icon="test_icon.png", + icon_dark="test_icon_dark.png", + label=I18nObject(en_US=fake.company()), + type=ToolProviderType.API, + masked_credentials={}, + is_team_authorization=True, + plugin_id="test_plugin_id", + tools=[], + labels=[], + ) + + # Act: Execute the method under test + ToolTransformService.repack_provider(tenant_id, provider) + + # Assert: Verify the expected outcomes + assert provider.icon is not None + assert isinstance(provider.icon, str) + assert "console/api/workspaces/current/plugin/icon" in provider.icon + assert tenant_id in provider.icon + assert "test_icon.png" in provider.icon + + # Verify dark icon handling + assert provider.icon_dark is not None + assert isinstance(provider.icon_dark, str) + assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark + assert tenant_id in provider.icon_dark + assert "test_icon_dark.png" in provider.icon_dark + + def test_repack_provider_entity_no_plugin_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful provider repacking with ToolProviderApiEntity input without plugin_id. + + This test verifies: + - Proper icon URL generation for non-plugin providers + - Regular tool provider icon handling + - Dark icon handling when present + """ + # Arrange: Setup test data + fake = Faker() + tenant_id = fake.uuid4() + + # Create provider entity without plugin_id + provider = ToolProviderApiEntity( + id=fake.uuid4(), + author=fake.name(), + name=fake.company(), + description=I18nObject(en_US=fake.text(max_nb_chars=100)), + icon='{"background": "#FF6B6B", "content": "🔧"}', + icon_dark='{"background": "#252525", "content": "🔧"}', + label=I18nObject(en_US=fake.company()), + type=ToolProviderType.API, + masked_credentials={}, + is_team_authorization=True, + plugin_id=None, + tools=[], + labels=[], + ) + + # Act: Execute the method under test + ToolTransformService.repack_provider(tenant_id, provider) + + # Assert: Verify the expected outcomes + assert provider.icon is not None + assert isinstance(provider.icon, dict) + assert provider.icon["background"] == "#FF6B6B" + assert provider.icon["content"] == "🔧" + + # Verify dark icon handling + assert provider.icon_dark is not None + assert isinstance(provider.icon_dark, dict) + assert provider.icon_dark["background"] == "#252525" + assert provider.icon_dark["content"] == "🔧" + + def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test provider repacking with ToolProviderApiEntity input without dark icon. 
+ + This test verifies: + - Proper handling when icon_dark is None or empty + - No errors when dark icon is not present + """ + # Arrange: Setup test data + fake = Faker() + tenant_id = fake.uuid4() + + # Create provider entity without dark icon + provider = ToolProviderApiEntity( + id=fake.uuid4(), + author=fake.name(), + name=fake.company(), + description=I18nObject(en_US=fake.text(max_nb_chars=100)), + icon='{"background": "#FF6B6B", "content": "🔧"}', + icon_dark=None, + label=I18nObject(en_US=fake.company()), + type=ToolProviderType.API, + masked_credentials={}, + is_team_authorization=True, + plugin_id=None, + tools=[], + labels=[], + ) + + # Act: Execute the method under test + ToolTransformService.repack_provider(tenant_id, provider) + + # Assert: Verify the expected outcomes + assert provider.icon is not None + assert isinstance(provider.icon, dict) + assert provider.icon["background"] == "#FF6B6B" + assert provider.icon["content"] == "🔧" + + # Verify dark icon remains None + assert provider.icon_dark is None + + def test_builtin_provider_to_user_provider_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful conversion of builtin provider to user provider. + + This test verifies: + - Proper entity creation with all required fields + - Credentials schema handling + - Team authorization setup + - Plugin ID handling + """ + # Arrange: Setup test data + fake = Faker() + + # Create mock provider controller + mock_controller = Mock() + mock_controller.entity.identity.name = fake.company() + mock_controller.entity.identity.author = fake.name() + mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100)) + mock_controller.entity.identity.icon = "🔧" + mock_controller.entity.identity.icon_dark = "🔧" + mock_controller.entity.identity.label = I18nObject(en_US=fake.company()) + mock_controller.plugin_id = None + mock_controller.plugin_unique_identifier = None + mock_controller.tool_labels = ["label1", "label2"] + mock_controller.need_credentials = True + + # Mock credentials schema + mock_credential = Mock() + mock_credential.to_basic_provider_config.return_value.name = "api_key" + mock_controller.get_credentials_schema_by_type.return_value = [mock_credential] + + # Create mock database provider + mock_db_provider = Mock() + mock_db_provider.credential_type = "api-key" + mock_db_provider.tenant_id = fake.uuid4() + mock_db_provider.credentials = {"api_key": "encrypted_key"} + + # Mock encryption + with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter: + mock_encrypter_instance = Mock() + mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"} + mock_encrypter_instance.mask_tool_credentials.return_value = {"api_key": ""} + mock_encrypter.return_value = (mock_encrypter_instance, None) + + # Act: Execute the method under test + result = ToolTransformService.builtin_provider_to_user_provider( + mock_controller, mock_db_provider, decrypt_credentials=True + ) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.id == mock_controller.entity.identity.name + assert result.author == mock_controller.entity.identity.author + assert result.name == mock_controller.entity.identity.name + assert result.description == mock_controller.entity.identity.description + assert result.icon == mock_controller.entity.identity.icon + assert result.icon_dark == mock_controller.entity.identity.icon_dark + assert result.label == 
mock_controller.entity.identity.label
+        assert result.type == ToolProviderType.BUILT_IN
+        assert result.is_team_authorization is True
+        assert result.plugin_id is None
+        assert result.tools == []
+        assert result.labels == ["label1", "label2"]
+        assert result.masked_credentials == {"api_key": ""}
+        assert result.original_credentials == {"api_key": "decrypted_key"}
+
+    def test_builtin_provider_to_user_provider_plugin_success(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test successful conversion of builtin provider to user provider with plugin.
+
+        This test verifies:
+        - Entity creation when the controller carries plugin metadata
+        - Team authorization and delete flags for plugin-backed providers
+        """
+        # Arrange: Setup test data
+        fake = Faker()
+
+        # Create mock provider controller with plugin
+        mock_controller = Mock()
+        mock_controller.entity.identity.name = fake.company()
+        mock_controller.entity.identity.author = fake.name()
+        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
+        mock_controller.entity.identity.icon = "🔧"
+        mock_controller.entity.identity.icon_dark = "🔧"
+        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
+        mock_controller.plugin_id = "test_plugin_id"
+        mock_controller.plugin_unique_identifier = "test_unique_id"
+        mock_controller.tool_labels = ["label1"]
+        mock_controller.need_credentials = False
+
+        # Mock credentials schema
+        mock_credential = Mock()
+        mock_credential.to_basic_provider_config.return_value.name = "api_key"
+        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
+
+        # Act: Execute the method under test
+        result = ToolTransformService.builtin_provider_to_user_provider(
+            mock_controller, None, decrypt_credentials=False
+        )
+
+        # Assert: Verify the expected outcomes
+        assert result is not None
+        # Note: plugin_id is only populated when the controller is a real
+        # PluginToolProviderController; a plain Mock fails that isinstance check,
+        # so plugin_id is not asserted here.
+        assert result.is_team_authorization is True
+        assert result.allow_delete is False
+
+    def test_builtin_provider_to_user_provider_no_credentials(
+        self, db_session_with_containers, mock_external_service_dependencies
+    ):
+        """
+        Test conversion of builtin provider to user provider without credentials.
+ + This test verifies: + - Proper handling when no credentials are needed + - Team authorization setup for no-credentials providers + """ + # Arrange: Setup test data + fake = Faker() + + # Create mock provider controller + mock_controller = Mock() + mock_controller.entity.identity.name = fake.company() + mock_controller.entity.identity.author = fake.name() + mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100)) + mock_controller.entity.identity.icon = "🔧" + mock_controller.entity.identity.icon_dark = "🔧" + mock_controller.entity.identity.label = I18nObject(en_US=fake.company()) + mock_controller.plugin_id = None + mock_controller.plugin_unique_identifier = None + mock_controller.tool_labels = [] + mock_controller.need_credentials = False + + # Mock credentials schema + mock_credential = Mock() + mock_credential.to_basic_provider_config.return_value.name = "api_key" + mock_controller.get_credentials_schema_by_type.return_value = [mock_credential] + + # Act: Execute the method under test + result = ToolTransformService.builtin_provider_to_user_provider( + mock_controller, None, decrypt_credentials=False + ) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.is_team_authorization is True + assert result.allow_delete is False + assert result.masked_credentials == {} + + def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful conversion of API provider to controller. + + This test verifies: + - Proper controller creation from database provider + - Auth type handling for different credential types + - Backward compatibility for auth types + """ + # Arrange: Setup test data + fake = Faker() + + # Create API tool provider with api_key_header auth + provider = ApiToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + tenant_id=fake.uuid4(), + user_id=fake.uuid4(), + credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', + schema="{}", + schema_type_str="openapi", + tools_str="[]", + ) + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + # Act: Execute the method under test + result = ToolTransformService.api_provider_to_controller(provider) + + # Assert: Verify the expected outcomes + assert result is not None + assert hasattr(result, "from_db") + # Additional assertions would depend on the actual controller implementation + + def test_api_provider_to_controller_api_key_query( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test conversion of API provider to controller with api_key_query auth type. 
+ + This test verifies: + - Proper auth type handling for query parameter authentication + """ + # Arrange: Setup test data + fake = Faker() + + # Create API tool provider with api_key_query auth + provider = ApiToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + tenant_id=fake.uuid4(), + user_id=fake.uuid4(), + credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', + schema="{}", + schema_type_str="openapi", + tools_str="[]", + ) + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + # Act: Execute the method under test + result = ToolTransformService.api_provider_to_controller(provider) + + # Assert: Verify the expected outcomes + assert result is not None + assert hasattr(result, "from_db") + + def test_api_provider_to_controller_backward_compatibility( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test conversion of API provider to controller with backward compatibility auth types. + + This test verifies: + - Proper handling of legacy auth type values + - Backward compatibility for api_key and api_key_header + """ + # Arrange: Setup test data + fake = Faker() + + # Create API tool provider with legacy auth type + provider = ApiToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + tenant_id=fake.uuid4(), + user_id=fake.uuid4(), + credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', + schema="{}", + schema_type_str="openapi", + tools_str="[]", + ) + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + # Act: Execute the method under test + result = ToolTransformService.api_provider_to_controller(provider) + + # Assert: Verify the expected outcomes + assert result is not None + assert hasattr(result, "from_db") + + def test_workflow_provider_to_controller_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful conversion of workflow provider to controller. 
+ + This test verifies: + - Proper controller creation from workflow provider + - Workflow-specific controller handling + """ + # Arrange: Setup test data + fake = Faker() + + # Create workflow tool provider + provider = WorkflowToolProvider( + name=fake.company(), + description=fake.text(max_nb_chars=100), + icon='{"background": "#FF6B6B", "content": "🔧"}', + tenant_id=fake.uuid4(), + user_id=fake.uuid4(), + app_id=fake.uuid4(), + label="Test Workflow", + version="1.0.0", + parameter_configuration="[]", + ) + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + # Mock the WorkflowToolProviderController.from_db method to avoid app dependency + with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: + mock_controller = Mock() + mock_from_db.return_value = mock_controller + + # Act: Execute the method under test + result = ToolTransformService.workflow_provider_to_controller(provider) + + # Assert: Verify the expected outcomes + assert result is not None + assert result == mock_controller + mock_from_db.assert_called_once_with(provider) diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts index b4c4f1540d..b579f22d4b 100644 --- a/web/__tests__/check-i18n.test.ts +++ b/web/__tests__/check-i18n.test.ts @@ -621,7 +621,7 @@ export default translation && !trimmed.startsWith('//')) break } - else { + else { break } diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx index 85263b035f..a78a4e632e 100644 --- a/web/__tests__/description-validation.test.tsx +++ b/web/__tests__/description-validation.test.tsx @@ -60,7 +60,7 @@ describe('Description Validation Logic', () => { try { validateDescriptionLength(invalidDescription) } - catch (error) { + catch (error) { expect((error as Error).message).toBe(expectedErrorMessage) } }) @@ -86,7 +86,7 @@ describe('Description Validation Logic', () => { expect(() => validateDescriptionLength(testDescription)).not.toThrow() expect(validateDescriptionLength(testDescription)).toBe(testDescription) } - else { + else { expect(() => validateDescriptionLength(testDescription)).toThrow( 'Description cannot exceed 400 characters.', ) diff --git a/web/__tests__/document-list-sorting.test.tsx b/web/__tests__/document-list-sorting.test.tsx index 1510dbec23..77c0bb60cf 100644 --- a/web/__tests__/document-list-sorting.test.tsx +++ b/web/__tests__/document-list-sorting.test.tsx @@ -39,7 +39,7 @@ describe('Document List Sorting', () => { const result = aValue.localeCompare(bValue) return order === 'asc' ? result : -result } - else { + else { const result = aValue - bValue return order === 'asc' ? 
result : -result } diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx index 370052bc80..87bda8fa13 100644 --- a/web/__tests__/plugin-tool-workflow-error.test.tsx +++ b/web/__tests__/plugin-tool-workflow-error.test.tsx @@ -196,7 +196,7 @@ describe('Plugin Tool Workflow Integration', () => { const _pluginId = (tool.uniqueIdentifier as any).split(':')[0] }).toThrow() } - else { + else { // Valid tools should work fine expect(() => { const _pluginId = tool.uniqueIdentifier.split(':')[0] diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx index cf3abd5f80..52bdf4777f 100644 --- a/web/__tests__/real-browser-flicker.test.tsx +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -252,7 +252,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => { if (hasStyleChange) console.log('⚠️ Style changes detected - this causes visible flicker') - else + else console.log('✅ No style changes detected') expect(timingData.length).toBeGreaterThan(1) diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx index 0843122ab4..64e9d328f0 100644 --- a/web/__tests__/workflow-parallel-limit.test.tsx +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -15,7 +15,7 @@ const originalEnv = process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT function setupEnvironment(value?: string) { if (value) process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = value - else + else delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT // Clear module cache to force re-evaluation @@ -25,7 +25,7 @@ function setupEnvironment(value?: string) { function restoreEnvironment() { if (originalEnv) process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = originalEnv - else + else delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT jest.resetModules() diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx index a3281be8eb..b1e915b2bf 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -47,7 +47,7 @@ describe('SVG Attribute Error Reproduction', () => { console.log(` ${index + 1}. 
${error.substring(0, 100)}...`) }) } - else { + else { console.log('No inkscape errors found in this render') } @@ -150,7 +150,7 @@ describe('SVG Attribute Error Reproduction', () => { if (problematicKeys.length > 0) console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`) - else + else console.log('✅ No problematic attributes found after normalization') }) }) diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 0408d2ee34..5890c2ea92 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -106,7 +106,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { onClick={() => { if (hoverArea === 'right' && !onAvatarError) setIsShowDeleteConfirm(true) - else + else setIsShowAvatarPicker(true) }} onMouseMove={(e) => { diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 00357d6c27..77a965c03e 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -45,8 +45,8 @@ const ICON_MAP = { , dataset: , webapp:
- -
, + + , notion: , } diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index c3ff45d6a6..c60aa26f5d 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -62,12 +62,12 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati }, [appSidebarExpand, setAppSiderbarExpand]) if (inWorkflowCanvas && hideHeader) { - return ( + return (
) -} + } return (
{ })) }) - describe('Issue #1: Toggle Button Position Movement - FIXED', () => { + describe('Issue #1: Toggle Button Position Movement - FIXED', () => { it('should verify consistent padding prevents button position shift', () => { let expanded = false const handleToggle = () => { diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index bb2a95b0b5..afa8732701 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -84,7 +84,7 @@ const Annotation: FC = (props) => { setList(data as AnnotationItem[]) setTotal(total) } - finally { + finally { setIsLoading(false) } } diff --git a/web/app/components/app/configuration/config-var/config-modal/type-select.tsx b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx index 3f6a01ed7c..beb7b03e37 100644 --- a/web/app/components/app/configuration/config-var/config-modal/type-select.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/type-select.tsx @@ -52,13 +52,13 @@ const TypeSelector: FC = ({ >
- - {selectedItem?.name} - + > + {selectedItem?.name} +
{inputVarTypeToVarType(selectedItem?.value as InputVarType)} diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 86025f68fa..cb61b927bc 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -175,7 +175,6 @@ const ConfigContent: FC = ({ ...datasetConfigs, reranking_enable: enable, }) - // eslint-disable-next-line react-hooks/exhaustive-deps }, [currentRerankModel, datasetConfigs, onChange]) return ( diff --git a/web/app/components/app/configuration/debug/chat-user-input.tsx b/web/app/components/app/configuration/debug/chat-user-input.tsx index ac07691ce4..b1161de075 100644 --- a/web/app/components/app/configuration/debug/chat-user-input.tsx +++ b/web/app/components/app/configuration/debug/chat-user-input.tsx @@ -57,10 +57,10 @@ const ChatUserInput = ({ >
{type !== 'checkbox' && ( -
-
{name || key}
- {!required && {t('workflow.panel.optional')}} -
+
+
{name || key}
+ {!required && {t('workflow.panel.optional')}} +
)}
{type === 'string' && ( diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 67b8065745..b73d1f19de 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -112,72 +112,72 @@ const getFormattedChatList = (messages: ChatMessage[], conversationId: string, t const newChatList: IChatItem[] = [] try { messages.forEach((item: ChatMessage) => { - const questionFiles = item.message_files?.filter((file: any) => file.belongs_to === 'user') || [] - newChatList.push({ - id: `question-${item.id}`, - content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query - isAnswer: false, - message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id }))), - parentMessageId: item.parent_message_id || undefined, - }) + const questionFiles = item.message_files?.filter((file: any) => file.belongs_to === 'user') || [] + newChatList.push({ + id: `question-${item.id}`, + content: item.inputs.query || item.inputs.default_input || item.query, // text generation: item.inputs.query; chat: item.query + isAnswer: false, + message_files: getProcessedFilesFromResponse(questionFiles.map((item: any) => ({ ...item, related_id: item.id }))), + parentMessageId: item.parent_message_id || undefined, + }) - const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] - newChatList.push({ - id: item.id, - content: item.answer, - agent_thoughts: addFileInfos(item.agent_thoughts ? sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), - feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback - adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback - feedbackDisabled: false, - isAnswer: true, - message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))), - log: [ - ...item.message, - ...(item.message[item.message.length - 1]?.role !== 'assistant' - ? [ - { - role: 'assistant', - text: item.answer, - files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], - }, - ] - : []), - ] as IChatItem['log'], - workflow_run_id: item.workflow_run_id, - conversationId, - input: { - inputs: item.inputs, - query: item.query, - }, - more: { - time: dayjs.unix(item.created_at).tz(timezone).format(format), - tokens: item.answer_tokens + item.message_tokens, - latency: item.provider_response_latency.toFixed(2), - }, - citation: item.metadata?.retriever_resources, - annotation: (() => { - if (item.annotation_hit_history) { - return { - id: item.annotation_hit_history.annotation_id, - authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', - created_at: item.annotation_hit_history.created_at, + const answerFiles = item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [] + newChatList.push({ + id: item.id, + content: item.answer, + agent_thoughts: addFileInfos(item.agent_thoughts ? 
sortAgentSorts(item.agent_thoughts) : item.agent_thoughts, item.message_files), + feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback + adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback + feedbackDisabled: false, + isAnswer: true, + message_files: getProcessedFilesFromResponse(answerFiles.map((item: any) => ({ ...item, related_id: item.id }))), + log: [ + ...item.message, + ...(item.message[item.message.length - 1]?.role !== 'assistant' + ? [ + { + role: 'assistant', + text: item.answer, + files: item.message_files?.filter((file: any) => file.belongs_to === 'assistant') || [], + }, + ] + : []), + ] as IChatItem['log'], + workflow_run_id: item.workflow_run_id, + conversationId, + input: { + inputs: item.inputs, + query: item.query, + }, + more: { + time: dayjs.unix(item.created_at).tz(timezone).format(format), + tokens: item.answer_tokens + item.message_tokens, + latency: item.provider_response_latency.toFixed(2), + }, + citation: item.metadata?.retriever_resources, + annotation: (() => { + if (item.annotation_hit_history) { + return { + id: item.annotation_hit_history.annotation_id, + authorName: item.annotation_hit_history.annotation_create_account?.name || 'N/A', + created_at: item.annotation_hit_history.created_at, + } } - } - if (item.annotation) { - return { - id: item.annotation.id, - authorName: item.annotation.account.name, - logAnnotation: item.annotation, - created_at: 0, + if (item.annotation) { + return { + id: item.annotation.id, + authorName: item.annotation.account.name, + logAnnotation: item.annotation, + created_at: 0, + } } - } - return undefined - })(), - parentMessageId: `question-${item.id}`, + return undefined + })(), + parentMessageId: `question-${item.id}`, + }) }) - }) return newChatList } @@ -503,7 +503,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id)) } - catch (error) { + catch (error) { console.error(error) setHasMore(false) } @@ -522,7 +522,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { if (outerDiv && outerDiv.scrollHeight > outerDiv.clientHeight) { scrollContainer = outerDiv } - else if (scrollableDiv && scrollableDiv.scrollHeight > scrollableDiv.clientHeight) { + else if (scrollableDiv && scrollableDiv.scrollHeight > scrollableDiv.clientHeight) { scrollContainer = scrollableDiv } else if (chatContainer && chatContainer.scrollHeight > chatContainer.clientHeight) { diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index 8713c8ef7b..c6df0ebfd9 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -167,7 +167,7 @@ function AppCard({ setAppDetail(res) setShowAccessControl(false) } - catch (error) { + catch (error) { console.error('Failed to fetch app detail:', error) } }, [appDetail, setAppDetail]) diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index cd25c4ca65..6eba993e1d 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -40,12 +40,12 @@ const OPTION_MAP = { `