From beda78e91129874d7a2b5377f71f2cc2ca6dc1bb Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 1 Apr 2026 06:00:05 +0200 Subject: [PATCH 1/4] refactor: select in 13 small service files (#34371) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/audio_service.py | 2 +- api/services/billing_service.py | 7 +++-- api/services/conversation_service.py | 12 ++++---- api/services/credit_pool_service.py | 14 ++++----- .../enterprise/account_deletion_sync.py | 5 +++- .../rag_pipeline/pipeline_generate_service.py | 2 +- .../customized/customized_retrieval.py | 12 ++++---- .../database/database_retrieval.py | 11 +++---- .../database/database_retrieval.py | 8 ++--- api/services/web_conversation_service.py | 12 ++++---- api/services/webapp_auth_service.py | 5 ++-- api/services/workflow/workflow_converter.py | 7 +++-- api/services/workspace_service.py | 7 +++-- .../unit_tests/services/test_audio_service.py | 21 ++++--------- .../services/test_billing_service.py | 30 ++++--------------- .../services/test_conversation_service.py | 19 ++++-------- 16 files changed, 72 insertions(+), 102 deletions(-) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 90e72d5f34..1c7027efb4 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -132,7 +132,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.query(Message).where(Message.id == message_id).first() + message = db.session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54c595e0cb..9970b2e604 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -6,6 +6,7 @@ from typing import Literal import httpx from pydantic import TypeAdapter +from sqlalchemy import select from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError @@ -158,10 +159,10 @@ class BillingService: def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id - join: TenantAccountJoin | None = ( - db.session.query(TenantAccountJoin) + join: TenantAccountJoin | None = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) if not join: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ba1e7bb826..95482a2235 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -137,11 +137,11 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = ( - db.session.query(Message) + message = db.session.scalar( + select(Message) .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .order_by(Message.created_at.asc()) - .first() + .limit(1) ) if not message: @@ -160,8 +160,8 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, @@ -170,7 +170,7 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), Conversation.is_deleted == False, ) - .first() + .limit(1) ) if not conversation: diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 2894826935..7826695366 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,6 +1,6 @@ import logging -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import Session from configs import dify_config @@ -29,13 +29,13 @@ class CreditPoolService: @classmethod def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - return ( - db.session.query(TenantCreditPool) - .filter_by( - tenant_id=tenant_id, - pool_type=pool_type, + return db.session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, ) - .first() + .limit(1) ) @classmethod diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py index c7ff42894d..b5107fb0f6 100644 --- a/api/services/enterprise/account_deletion_sync.py +++ b/api/services/enterprise/account_deletion_sync.py @@ -4,6 +4,7 @@ import uuid from datetime import UTC, datetime from redis import RedisError +from sqlalchemy import select from configs import dify_config from extensions.ext_database import db @@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool: return True # Fetch all workspaces the account belongs to - workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + workspace_joins = db.session.scalars( + select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id) + ).all() # Queue sync task for each workspace success = True diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 07e1b8f20e..10e89b1dba 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -110,7 +110,7 @@ class PipelineGenerateService: Update document status to waiting :param document_id: document id """ - document = db.session.query(Document).where(Document.id == document_id).first() + document = db.session.get(Document, document_id) if document: document.indexing_status = IndexingStatus.WAITING db.session.add(document) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 4ac2e0792b..2ee871a266 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_customized_templates = ( - db.session.query(PipelineCustomizedTemplate) + pipeline_customized_templates = db.session.scalars( + select(PipelineCustomizedTemplate) .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) - .all() - ) + ).all() recommended_pipelines_results = [] for pipeline_customized_template in pipeline_customized_templates: recommended_pipeline_result = { @@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - pipeline_template = ( - db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 908f9a2684..43b21a7b32 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from models.dataset import PipelineBuiltInTemplate @@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ - pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all() + pipeline_built_in_templates = list( + db.session.scalars( + select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language) + ).all() ) recommended_pipelines_results = [] @@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ # is in public recommended list - pipeline_template = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index d0c49325dc..6fb90d356d 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): :return: """ # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1) ) if not recommended_app: return None # get app detail - app_model = db.session.query(App).where(App.id == app_id).first() + app_model = db.session.get(App, app_id) if not app_model or not app_model.is_public: return None diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index e028e3e5e3..5ef9e9be61 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -64,15 +64,15 @@ class WebConversationService: def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if pinned_conversation: @@ -96,15 +96,15 @@ class WebConversationService: def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if not pinned_conversation: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 5ca0b63001..eaea79af2f 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -3,6 +3,7 @@ import secrets from datetime import UTC, datetime, timedelta from typing import Any +from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -92,10 +93,10 @@ class WebAppAuthService: @classmethod def create_end_user(cls, app_code, email) -> EndUser: - site = db.session.query(Site).where(Site.code == app_code).first() + site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1)) if not site: raise NotFound("Site not found.") - app_model = db.session.query(App).where(App.id == site.app_id).first() + app_model = db.session.get(App, site.app_id) if not app_model: raise NotFound("App not found.") end_user = EndUser( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 31367f72fa..399c82849f 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes import BuiltinNodeTypes from graphon.variables.input_entities import VariableEntity +from sqlalchemy import select from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -648,10 +649,10 @@ class WorkflowConverter: :param api_based_extension_id: api based extension id :return: """ - api_based_extension = ( - db.session.query(APIBasedExtension) + api_based_extension = db.session.scalar( + select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + .limit(1) ) if not api_based_extension: diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 84a8b03329..eb4671cfaa 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,5 @@ from flask_login import current_user +from sqlalchemy import select from configs import dify_config from enums.cloud_plan import CloudPlan @@ -24,10 +25,10 @@ class WorkspaceService: } # Get role of user - tenant_account_join = ( - db.session.query(TenantAccountJoin) + tenant_account_join = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 175fd3ee01..cede6671ce 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -421,11 +421,8 @@ class TestAudioServiceTTS: answer="Message answer text", ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Mock ModelManager mock_model_manager = mock_model_manager_class.return_value @@ -568,11 +565,8 @@ class TestAudioServiceTTS: # Arrange app = factory.create_app_mock() - # Mock database query returning None - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None + # Mock database lookup returning None + mock_db_session.get.return_value = None # Act result = AudioService.transcript_tts( @@ -594,11 +588,8 @@ class TestAudioServiceTTS: status=MessageStatus.NORMAL, ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Act result = AudioService.transcript_tts( diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index b3d2e60802..168ab6cf0d 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.OWNER - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_admin(self, mock_db_session): """Test tenant owner/admin check for admin role.""" # Arrange @@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.ADMIN - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session): """Test tenant owner/admin check raises error for normal user.""" # Arrange @@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.NORMAL - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement: current_user.id = "account-123" current_user.current_tenant_id = "tenant-456" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 1bf4c0e172..a4359f00b8 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -355,15 +355,13 @@ class TestConversationServiceGetConversation: from_account_id=user.id, from_source=ConversationFromSource.CONSOLE ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) # Assert assert result == conversation - mock_db_session.query.assert_called_once_with(Conversation) @patch("services.conversation_service.db.session") def test_get_conversation_success_with_end_user(self, mock_db_session): @@ -379,8 +377,7 @@ class TestConversationServiceGetConversation: from_end_user_id=user.id, from_source=ConversationFromSource.API ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) @@ -399,8 +396,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ConversationNotExistsError): @@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator mock_llm_generator.generate_conversation_name.return_value = "Generated Name" @@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName: conversation = ConversationServiceTestDataFactory.create_conversation_mock() # Mock database query to return None - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(MessageNotExistsError): @@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator to raise exception mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") From 09ee8ea1f535fc86a41e8370ef520abbe10ac54f Mon Sep 17 00:00:00 2001 From: Full Stack Engineer <66432853+EndlessLucky@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:22:23 -0400 Subject: [PATCH 2/4] fix: support qa_preview shape in IndexProcessor preview formatting (#34151) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> --- api/core/rag/index_processor/index_processor.py | 9 ++++++++- .../core/rag/indexing/test_index_processor.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/core/rag/indexing/test_index_processor.py diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index a6d1db214b..825ae01226 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -35,7 +35,10 @@ class IndexProcessor: if "parent_mode" in preview: data.parent_mode = preview["parent_mode"] - for item in preview["preview"]: + # Different index processors return different preview shapes: + # - paragraph/parent-child processors: {"preview": [...]} + # - QA processor: {"qa_preview": [...]} (no "preview" key) + for item in preview.get("preview", []): if "content" in item and "child_chunks" in item: data.preview.append( PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None) @@ -44,6 +47,10 @@ class IndexProcessor: data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"])) elif "content" in item: data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None)) + + for item in preview.get("qa_preview", []): + if "question" in item and "answer" in item: + data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"])) return data def index_and_clean( diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor.py new file mode 100644 index 0000000000..a3f284955b --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor.py @@ -0,0 +1,15 @@ +from core.rag.index_processor.index_processor import IndexProcessor + + +class TestIndexProcessor: + def test_format_preview_supports_qa_preview_shape(self) -> None: + preview = IndexProcessor().format_preview( + "qa_model", + {"qa_chunks": [{"question": "Q1", "answer": "A1"}]}, + ) + + assert preview.chunk_structure == "qa_model" + assert preview.total_segments == 1 + assert len(preview.qa_preview) == 1 + assert preview.qa_preview[0].question == "Q1" + assert preview.qa_preview[0].answer == "A1" From c51cd42cb4e21320664b6d0e9efcf2ecbd1ddec5 Mon Sep 17 00:00:00 2001 From: Dream <42954461+eureka928@users.noreply.github.com> Date: Wed, 1 Apr 2026 01:41:44 -0400 Subject: [PATCH 3/4] refactor(api): replace json.loads with Pydantic validation in controllers and infra layers (#34277) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 12 ++--- .../rag_pipeline/rag_pipeline_workflow.py | 23 ++-------- .../arize_phoenix_trace.py | 3 +- api/core/ops/mlflow_trace/mlflow_trace.py | 10 ++--- api/core/ops/ops_trace_manager.py | 23 +++++++--- api/core/ops/utils.py | 3 ++ .../alibabacloud_mysql_vector.py | 15 +++---- .../analyticdb/analyticdb_vector_openapi.py | 5 ++- .../rag/datasource/vdb/baidu/baidu_vector.py | 13 ++---- .../vdb/clickzetta/clickzetta_vector.py | 32 ++++++------- api/core/rag/datasource/vdb/field.py | 20 +++++++++ .../vdb/hologres/hologres_vector.py | 7 ++- .../rag/datasource/vdb/iris/iris_vector.py | 5 ++- .../vdb/matrixone/matrixone_vector.py | 7 +-- .../vdb/oceanbase/oceanbase_vector.py | 5 ++- .../vdb/tablestore/tablestore_vector.py | 9 ++-- .../datasource/vdb/tencent/tencent_vector.py | 12 +++-- .../datasource/vdb/tidb_vector/tidb_vector.py | 4 +- .../vdb/vikingdb/vikingdb_vector.py | 7 ++- ...tore_workflow_node_execution_repository.py | 9 ++-- .../clickzetta_volume/file_lifecycle.py | 8 +++- .../storage/google_cloud_storage.py | 7 ++- .../core/rag/datasource/vdb/test_field.py | 45 +++++++++++++++++++ 23 files changed, 170 insertions(+), 114 deletions(-) create mode 100644 api/tests/unit_tests/core/rag/datasource/vdb/test_field.py diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6df8f7032e..dcd24d2200 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,7 +9,7 @@ from graphon.enums import NodeType from graphon.file import File from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.utils.encoders import jsonable_encoder -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -268,22 +268,18 @@ class DraftWorkflowApi(Resource): content_type = request.headers.get("Content-Type", "") - payload_data: dict[str, Any] | None = None if "application/json" in content_type: payload_data = request.get_json(silent=True) if not isinstance(payload_data, dict): return {"message": "Invalid JSON data"}, 400 + args_model = SyncDraftWorkflowPayload.model_validate(payload_data) elif "text/plain" in content_type: try: - payload_data = json.loads(request.data.decode("utf-8")) - except json.JSONDecodeError: - return {"message": "Invalid JSON data"}, 400 - if not isinstance(payload_data, dict): + args_model = SyncDraftWorkflowPayload.model_validate_json(request.data) + except (ValueError, ValidationError): return {"message": "Invalid JSON data"}, 400 else: abort(415) - - args_model = SyncDraftWorkflowPayload.model_validate(payload_data) args = args_model.model_dump() workflow_service = WorkflowService() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index e08cb155b6..4251e7ebac 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -5,7 +5,7 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from graphon.model_runtime.utils.encoders import jsonable_encoder -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -186,29 +186,14 @@ class DraftRagPipelineApi(Resource): if "application/json" in content_type: payload_dict = console_ns.payload or {} + payload = DraftWorkflowSyncPayload.model_validate(payload_dict) elif "text/plain" in content_type: try: - data = json.loads(request.data.decode("utf-8")) - if "graph" not in data or "features" not in data: - raise ValueError("graph or features not found in data") - - if not isinstance(data.get("graph"), dict): - raise ValueError("graph is not a dict") - - payload_dict = { - "graph": data.get("graph"), - "features": data.get("features"), - "hash": data.get("hash"), - "environment_variables": data.get("environment_variables"), - "conversation_variables": data.get("conversation_variables"), - "rag_pipeline_variables": data.get("rag_pipeline_variables"), - } - except json.JSONDecodeError: + payload = DraftWorkflowSyncPayload.model_validate_json(request.data) + except (ValueError, ValidationError): return {"message": "Invalid JSON data"}, 400 else: abort(415) - - payload = DraftWorkflowSyncPayload.model_validate(payload_dict) rag_pipeline_service = RagPipelineService() try: diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 902f58e6b7..66933cea28 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -38,6 +38,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile @@ -469,7 +470,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider if trace_info.message_data and trace_info.message_data.message_metadata: - metadata_dict = json.loads(trace_info.message_data.message_metadata) + metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata) if model_params := metadata_dict.get("model_parameters"): llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 946d3cdd47..3d8c1dd038 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -1,4 +1,3 @@ -import json import logging import os from datetime import datetime, timedelta @@ -25,6 +24,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.utils import JSON_DICT_ADAPTER from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance): inputs = node.process_data # contains request URL if not inputs: - inputs = json.loads(node.inputs) if node.inputs else {} + inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {} node_span = start_span_no_context( name=node.title, @@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance): # End node span finished_at = node.created_at + timedelta(seconds=node.elapsed_time) - outputs = json.loads(node.outputs) if node.outputs else {} + outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {} if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: outputs = self._parse_knowledge_retrieval_outputs(outputs) elif node.node_type == BuiltinNodeTypes.LLM: @@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance): return {}, {} try: - data = json.loads(node.process_data) - except (json.JSONDecodeError, TypeError): + data = JSON_DICT_ADAPTER.validate_json(node.process_data) + except (ValueError, TypeError): return {}, {} inputs = self._parse_prompts(data.get("prompts")) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9c36d57c6f..c689a86614 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -11,8 +11,10 @@ from uuid import UUID, uuid4 from cachetools import LRUCache from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( @@ -33,7 +35,7 @@ from core.ops.entities.trace_entity import ( WorkflowNodeTraceInfo, WorkflowTraceInfo, ) -from core.ops.utils import get_message_data +from core.ops.utils import JSON_DICT_ADAPTER, get_message_data from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant @@ -50,6 +52,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _AppTracingConfig(TypedDict, total=False): + enabled: bool + tracing_provider: str | None + + +_app_tracing_config_adapter: TypeAdapter[_AppTracingConfig] = TypeAdapter(_AppTracingConfig) + + def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" app_name = "" @@ -468,7 +478,7 @@ class OpsTraceManager: if app is None: return None - app_ops_trace_config = json.loads(app.tracing) if app.tracing else None + app_ops_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) if app.tracing else None if app_ops_trace_config is None: return None if not app_ops_trace_config.get("enabled"): @@ -560,7 +570,7 @@ class OpsTraceManager: raise ValueError("App not found") if not app.tracing: return {"enabled": False, "tracing_provider": None} - app_trace_config = json.loads(app.tracing) + app_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) return app_trace_config @staticmethod @@ -636,7 +646,6 @@ class TraceTask: carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading large JSON blobs unnecessarily. """ - import json from models.workflow import WorkflowNodeExecutionModel @@ -658,7 +667,7 @@ class TraceTask: if not raw: continue try: - outputs = json.loads(raw) if isinstance(raw, str) else raw + outputs = JSON_DICT_ADAPTER.validate_json(raw) if isinstance(raw, str) else raw except (ValueError, TypeError): continue if not isinstance(outputs, dict): @@ -1420,7 +1429,7 @@ class TraceTask: return {} try: - metadata = json.loads(message_data.message_metadata) + metadata = JSON_DICT_ADAPTER.validate_json(message_data.message_metadata) usage = metadata.get("usage", {}) time_to_first_token = usage.get("time_to_first_token") time_to_generate = usage.get("time_to_generate") @@ -1430,7 +1439,7 @@ class TraceTask: "llm_streaming_time_to_generate": time_to_generate, "is_streaming_request": time_to_first_token is not None, } - except (json.JSONDecodeError, AttributeError): + except (ValueError, AttributeError): return {} diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 8b9a2e424a..a6f10c09ac 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -3,11 +3,14 @@ from datetime import datetime from typing import Any, Union from urllib.parse import urlparse +from pydantic import TypeAdapter from sqlalchemy import select from models.engine import db from models.model import Message +JSON_DICT_ADAPTER: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + def filter_none_values(data: dict[str, Any]) -> dict[str, Any]: new_data = {} diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py index fdb5ffebfc..6e76827a42 100644 --- a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -10,6 +10,7 @@ from mysql.connector import Error as MySQLError from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -178,9 +179,7 @@ class AlibabaCloudMySQLVector(BaseVector): cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) docs = [] for record in cur: - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) docs.append(Document(page_content=record["text"], metadata=metadata)) return docs @@ -263,15 +262,13 @@ class AlibabaCloudMySQLVector(BaseVector): # similarity = 1 / (1 + distance) similarity = 1.0 / (1.0 + distance) - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) metadata["score"] = similarity metadata["distance"] = distance if similarity >= score_threshold: docs.append(Document(page_content=record["text"], metadata=metadata)) - except (ValueError, json.JSONDecodeError) as e: + except (ValueError, TypeError) as e: logger.warning("Error processing search result: %s", e) continue @@ -306,9 +303,7 @@ class AlibabaCloudMySQLVector(BaseVector): ) docs = [] for record in cur: - metadata = record["meta"] - if isinstance(metadata, str): - metadata = json.loads(metadata) + metadata = parse_metadata_json(record["meta"]) metadata["score"] = float(record["score"]) docs.append(Document(page_content=record["text"], metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 702200e0ac..ce626bbd7e 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -8,6 +8,7 @@ _import_err_msg = ( "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" ) +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.models.document import Document from extensions.ext_redis import redis_client @@ -257,7 +258,7 @@ class AnalyticdbVectorOpenAPI: documents = [] for match in response.body.matches.match: if match.score >= score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) + metadata = parse_metadata_json(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), @@ -294,7 +295,7 @@ class AnalyticdbVectorOpenAPI: documents = [] for match in response.body.matches.match: if match.score >= score_threshold: - metadata = json.loads(match.metadata.get("metadata_")) + metadata = parse_metadata_json(match.metadata.get("metadata_")) metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 9f5842e449..3173920c9c 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -29,6 +29,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, from configs import dify_config from core.rag.datasource.vdb.field import Field as VDBField +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -173,15 +174,9 @@ class BaiduVector(BaseVector): score = row.get("score", 0.0) meta = row_data.get(VDBField.METADATA_KEY, {}) - # Handle both JSON string and dict formats for backward compatibility - if isinstance(meta, str): - try: - import json - - meta = json.loads(meta) - except (json.JSONDecodeError, TypeError): - meta = {} - elif not isinstance(meta, dict): + try: + meta = parse_metadata_json(meta) + except (ValueError, TypeError): meta = {} if score >= score_threshold: diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 8e8120fc10..a4dddc68f0 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from clickzetta.connector.v0.connection import Connection # type: ignore from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.embedding.embedding_base import Embeddings @@ -357,18 +357,19 @@ class ClickzettaVector(BaseVector): """ try: if raw_metadata: - metadata = json.loads(raw_metadata) + # First parse may yield a string (double-encoded JSON) so use json.loads + first_pass = json.loads(raw_metadata) # Handle double-encoded JSON - if isinstance(metadata, str): - metadata = json.loads(metadata) - - # Ensure we have a dict - if not isinstance(metadata, dict): + if isinstance(first_pass, str): + metadata = parse_metadata_json(first_pass) + elif isinstance(first_pass, dict): + metadata = first_pass + else: metadata = {} else: metadata = {} - except (json.JSONDecodeError, TypeError): + except (json.JSONDecodeError, ValueError, TypeError): logger.exception("JSON parsing failed for metadata") # Fallback: extract document_id with regex doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") @@ -930,17 +931,18 @@ class ClickzettaVector(BaseVector): # Parse metadata from JSON string (may be double-encoded) try: if row[2]: - metadata = json.loads(row[2]) + # First parse may yield a string (double-encoded JSON) + first_pass = json.loads(row[2]) - # If result is a string, it's double-encoded JSON - parse again - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if not isinstance(metadata, dict): + if isinstance(first_pass, str): + metadata = parse_metadata_json(first_pass) + elif isinstance(first_pass, dict): + metadata = first_pass + else: metadata = {} else: metadata = {} - except (json.JSONDecodeError, TypeError): + except (json.JSONDecodeError, ValueError, TypeError): logger.exception("JSON parsing failed") # Fallback: extract document_id with regex diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 8fc94be360..5a0fabc572 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,4 +1,24 @@ from enum import StrEnum, auto +from typing import Any + +from pydantic import TypeAdapter + +_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + + +def parse_metadata_json(raw: Any) -> dict[str, Any]: + """Parse metadata from a JSON string or pass through an existing dict. + + Many VDB drivers return metadata as either a JSON string or an already- + decoded dict depending on the column type and driver version. + """ + if raw is None or raw in ("", b""): + return {} + if isinstance(raw, dict): + return raw + if not isinstance(raw, (str, bytes, bytearray)): + return {} + return _metadata_adapter.validate_json(raw) class Field(StrEnum): diff --git a/api/core/rag/datasource/vdb/hologres/hologres_vector.py b/api/core/rag/datasource/vdb/hologres/hologres_vector.py index 36b259e494..13d48b5668 100644 --- a/api/core/rag/datasource/vdb/hologres/hologres_vector.py +++ b/api/core/rag/datasource/vdb/hologres/hologres_vector.py @@ -9,6 +9,7 @@ from psycopg import sql as psql from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -217,8 +218,7 @@ class HologresVector(BaseVector): text = row[2] meta = row[3] - if isinstance(meta, str): - meta = json.loads(meta) + meta = parse_metadata_json(meta) # Convert distance to similarity score (consistent with pgvector) score = 1 - distance @@ -265,8 +265,7 @@ class HologresVector(BaseVector): meta = row[2] score = row[-1] # score is the last column from return_score - if isinstance(meta, str): - meta = json.loads(meta) + meta = parse_metadata_json(meta) meta["score"] = score docs.append(Document(page_content=text, metadata=meta)) diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/core/rag/datasource/vdb/iris/iris_vector.py index 50bb2429ec..aae445e6ff 100644 --- a/api/core/rag/datasource/vdb/iris/iris_vector.py +++ b/api/core/rag/datasource/vdb/iris/iris_vector.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any from configs import dify_config from configs.middleware.vdb.iris_config import IrisVectorConfig +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -269,7 +270,7 @@ class IrisVector(BaseVector): if len(row) >= 4: text, meta_str, score = row[1], row[2], float(row[3]) if score >= score_threshold: - metadata = json.loads(meta_str) if meta_str else {} + metadata = parse_metadata_json(meta_str) metadata["score"] = score docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -384,7 +385,7 @@ class IrisVector(BaseVector): meta_str = row[2] score_value = row[3] - metadata = json.loads(meta_str) if meta_str else {} + metadata = parse_metadata_json(meta_str) # Add score to metadata for hybrid search compatibility score = float(score_value) if score_value is not None else 0.0 metadata["score"] = score diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 14955c8d7c..09ef498715 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -9,6 +9,7 @@ from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -196,11 +197,7 @@ class MatrixoneVector(BaseVector): docs = [] for result in results: - metadata = result.metadata - if isinstance(metadata, str): - import json - - metadata = json.loads(metadata) + metadata = parse_metadata_json(result.metadata) score = 1 - result.distance if score >= score_threshold: metadata["score"] = score diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 86c1e65f47..82f419871c 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -10,6 +10,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.exc import SQLAlchemyError from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -366,8 +367,8 @@ class OceanBaseVector(BaseVector): # Parse metadata JSON try: - metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str - except json.JSONDecodeError: + metadata = parse_metadata_json(metadata_str) + except (ValueError, TypeError): logger.warning("Invalid JSON metadata: %s", metadata_str) metadata = {} diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index f2156afa59..4a734232ec 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -73,7 +73,8 @@ class TableStoreVector(BaseVector): for item in table_result: if item.is_ok and item.row: kv = {k: v for k, v, _ in item.row.attribute_columns} - docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY]))) + metadata = parse_metadata_json(kv[Field.METADATA_KEY]) + docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata)) return docs def get_type(self) -> str: @@ -311,7 +312,7 @@ class TableStoreVector(BaseVector): metadata_str = ots_column_map.get(Field.METADATA_KEY) vector = json.loads(vector_str) if vector_str else None - metadata = json.loads(metadata_str) if metadata_str else {} + metadata = parse_metadata_json(metadata_str) metadata["score"] = search_hit.score @@ -371,7 +372,7 @@ class TableStoreVector(BaseVector): ots_column_map[col[0]] = col[1] metadata_str = ots_column_map.get(Field.METADATA_KEY) - metadata = json.loads(metadata_str) if metadata_str else {} + metadata = parse_metadata_json(metadata_str) vector_str = ots_column_map.get(Field.VECTOR) vector = json.loads(vector_str) if vector_str else None diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 291d047c04..829db9db20 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -11,6 +11,7 @@ from tcvectordb.model import index as vdb_index # type: ignore from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore from configs import dify_config +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -286,13 +287,10 @@ class TencentVector(BaseVector): return docs for result in res[0]: - meta = result.get(self.field_metadata) - if isinstance(meta, str): - # Compatible with version 1.1.3 and below. - meta = json.loads(meta) - score = 1 - result.get("score", 0.0) - else: - score = result.get("score", 0.0) + raw_meta = result.get(self.field_metadata) + # Compatible with version 1.1.3 and below: str means old driver. + score = (1 - result.get("score", 0.0)) if isinstance(raw_meta, str) else result.get("score", 0.0) + meta = parse_metadata_json(raw_meta) if score >= score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 27ae038a06..c948917374 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,7 +9,7 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.field import Field, parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -228,7 +228,7 @@ class TiDBVector(BaseVector): ) results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: - metadata = json.loads(meta) + metadata = parse_metadata_json(meta) metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index e5feecf2bc..83fd3626d9 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -15,6 +15,7 @@ from volcengine.viking_db import ( # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field as vdb_Field +from core.rag.datasource.vdb.field import parse_metadata_json from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -163,7 +164,7 @@ class VikingDBVector(BaseVector): for result in results: metadata = result.fields.get(vdb_Field.METADATA_KEY) if metadata is not None: - metadata = json.loads(metadata) + metadata = parse_metadata_json(metadata) if metadata.get(key) == value: ids.append(result.id) return ids @@ -189,9 +190,7 @@ class VikingDBVector(BaseVector): docs = [] for result in results: - metadata = result.fields.get(vdb_Field.METADATA_KEY) - if metadata is not None: - metadata = json.loads(metadata) + metadata = parse_metadata_json(result.fields.get(vdb_Field.METADATA_KEY)) if result.score >= score_threshold: metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata) diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index b725436681..0e9a19b821 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -20,6 +20,7 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore @@ -48,10 +49,10 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut """ logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5]) # Parse JSON fields - inputs = json.loads(data.get("inputs", "{}")) - process_data = json.loads(data.get("process_data", "{}")) - outputs = json.loads(data.get("outputs", "{}")) - metadata = json.loads(data.get("execution_metadata", "{}")) + inputs = JSON_DICT_ADAPTER.validate_json(data.get("inputs") or "{}") + process_data = JSON_DICT_ADAPTER.validate_json(data.get("process_data") or "{}") + outputs = JSON_DICT_ADAPTER.validate_json(data.get("outputs") or "{}") + metadata = JSON_DICT_ADAPTER.validate_json(data.get("execution_metadata") or "{}") # Convert metadata to domain enum keys domain_metadata = {} diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 1d9911465b..483bd6bbf6 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -15,8 +15,12 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from pydantic import TypeAdapter + logger = logging.getLogger(__name__) +_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + class FileStatus(StrEnum): """File status enumeration""" @@ -455,8 +459,8 @@ class FileLifecycleManager: try: if self._storage.exists(self._metadata_file): metadata_content = self._storage.load_once(self._metadata_file) - result = json.loads(metadata_content.decode("utf-8")) - return dict(result) if result else {} + result = _metadata_adapter.validate_json(metadata_content) + return result or {} else: return {} except Exception as e: diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 4ad7e2d159..00f7289aa4 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -1,13 +1,16 @@ import base64 import io -import json from collections.abc import Generator +from typing import Any from google.cloud import storage as google_cloud_storage # type: ignore +from pydantic import TypeAdapter from configs import dify_config from extensions.storage.base_storage import BaseStorage +_service_account_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + class GoogleCloudStorage(BaseStorage): """Implementation for Google Cloud storage.""" @@ -21,7 +24,7 @@ class GoogleCloudStorage(BaseStorage): if service_account_json_str: service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") # convert str to object - service_account_obj = json.loads(service_account_json) + service_account_obj = _service_account_adapter.validate_json(service_account_json) self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) else: self.client = google_cloud_storage.Client() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py new file mode 100644 index 0000000000..d68c93b021 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_field.py @@ -0,0 +1,45 @@ +import pytest + +from core.rag.datasource.vdb.field import parse_metadata_json + + +class TestParseMetadataJson: + def test_none_returns_empty_dict(self): + assert parse_metadata_json(None) == {} + + def test_empty_string_returns_empty_dict(self): + assert parse_metadata_json("") == {} + + def test_valid_json_string(self): + result = parse_metadata_json('{"doc_id": "abc", "score": 0.9}') + assert result == {"doc_id": "abc", "score": 0.9} + + def test_dict_passthrough(self): + original = {"doc_id": "abc", "document_id": "123"} + result = parse_metadata_json(original) + assert result == original + + def test_empty_json_object(self): + assert parse_metadata_json("{}") == {} + + def test_invalid_json_raises_value_error(self): + with pytest.raises(ValueError): + parse_metadata_json("{invalid json") + + def test_nested_metadata(self): + result = parse_metadata_json('{"doc_id": "1", "extra": {"nested": true}}') + assert result["extra"]["nested"] is True + + def test_non_str_non_dict_returns_empty_dict(self): + assert parse_metadata_json(123) == {} + assert parse_metadata_json([1, 2]) == {} + + def test_bytes_input(self): + result = parse_metadata_json(b'{"key": "value"}') + assert result == {"key": "value"} + + def test_empty_bytes_returns_empty_dict(self): + assert parse_metadata_json(b"") == {} + + def test_empty_bytearray_returns_empty_dict(self): + assert parse_metadata_json(bytearray(b"")) == {} From b23ea0397a756d7b6f267c5789a292eabbb1c502 Mon Sep 17 00:00:00 2001 From: jimmyzhuu Date: Wed, 1 Apr 2026 14:16:09 +0800 Subject: [PATCH 4/4] fix: apply Baidu Vector DB connection timeout when initializing Mochow client (#34328) --- api/core/rag/datasource/vdb/baidu/baidu_vector.py | 6 +++++- .../rag/datasource/vdb/baidu/test_baidu_vector.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 3173920c9c..2b220fc04d 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -195,7 +195,11 @@ class BaiduVector(BaseVector): raise def _init_client(self, config) -> MochowClient: - config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) + config = Configuration( + credentials=BceCredentials(config.account, config.api_key), + endpoint=config.endpoint, + connection_timeout_in_mills=config.connection_timeout_in_mills, + ) client = MochowClient(config) return client diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py index c46c3d5e4b..487d021697 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -381,13 +381,22 @@ def test_init_client_constructs_configuration_and_client(baidu_module, monkeypat monkeypatch.setattr(baidu_module, "MochowClient", client_cls) vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) - config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + config = SimpleNamespace( + account="account", + api_key="key", + endpoint="https://endpoint", + connection_timeout_in_mills=12_345, + ) client = vector._init_client(config) assert client == "client" credentials.assert_called_once_with("account", "key") - configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + configuration.assert_called_once_with( + credentials="credentials", + endpoint="https://endpoint", + connection_timeout_in_mills=12_345, + ) client_cls.assert_called_once_with("configuration")