From be1f4b34f88a0802545966162a94ae74caa332e7 Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:42:39 -0500 Subject: [PATCH 01/14] refactor(api): use sessionmaker in workflow & RAG pipeline services (#34805) --- api/services/rag_pipeline/rag_pipeline.py | 6 ++---- api/services/workflow_draft_variable_service.py | 3 +-- api/services/workflow_service.py | 6 ++---- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index b330e1a46a..f6d80f9a6e 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -555,7 +555,7 @@ class RagPipelineService: workflow_node_execution.id ) - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=pipeline.id, @@ -569,7 +569,6 @@ class RagPipelineService: process_data=workflow_node_execution.process_data, outputs=workflow_node_execution.outputs, ) - session.commit() if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel): enqueue_draft_node_execution_trace( execution=workflow_node_execution_db_model, @@ -1325,7 +1324,7 @@ class RagPipelineService: # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=pipeline.id, @@ -1339,7 +1338,6 @@ class RagPipelineService: process_data=workflow_node_execution.process_data, outputs=workflow_node_execution.outputs, ) - session.commit() enqueue_draft_node_execution_trace( execution=workflow_node_execution_db_model, outputs=workflow_node_execution.outputs, diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 9ed60bf86b..1c1b94ae9d 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -1075,9 +1075,8 @@ class DraftVariableSaver: ) engine = bind = self._session.get_bind() assert isinstance(engine, Engine) - with Session(bind=engine, expire_on_commit=False) as session: + with sessionmaker(bind=engine, expire_on_commit=False).begin() as session: session.add(variable_file) - session.commit() return truncation_result.result, variable_file diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index eaffb60c63..1e3feeed29 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -837,7 +837,7 @@ class WorkflowService: with sessionmaker(db.engine).begin() as session: outputs = workflow_node_execution.load_full_outputs(session, storage) - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=app_model.id, @@ -848,7 +848,6 @@ class WorkflowService: user=account, ) draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) - session.commit() enqueue_draft_node_execution_trace( execution=workflow_node_execution, @@ -977,7 +976,7 @@ class WorkflowService: enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=app_model.id, @@ -988,7 +987,6 @@ class WorkflowService: enclosing_node_id=enclosing_node_id, ) draft_var_saver.save(outputs=outputs, process_data={}) - session.commit() return outputs From a76a8876d14559672daae5f7cf19f8701fcbd43e Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:43:13 -0500 Subject: [PATCH 02/14] refactor(api): use sessionmaker in datasource provider service (#34811) --- api/services/datasource_provider_service.py | 32 +++++++------------ .../test_datasource_provider_service.py | 17 ++++------ 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index faa978afdc..d5f8cd30bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -5,7 +5,7 @@ from typing import Any from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -53,13 +53,12 @@ class DatasourceProviderService: """ remove oauth custom client params """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(DatasourceOauthTenantParamConfig).filter_by( tenant_id=tenant_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, ).delete() - session.commit() def decrypt_datasource_provider_credentials( self, @@ -109,7 +108,7 @@ class DatasourceProviderService: """ get credential by id """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: if credential_id: datasource_provider = ( session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() @@ -156,7 +155,6 @@ class DatasourceProviderService: datasource_provider=datasource_provider, ) datasource_provider.expires_at = refreshed_credentials.expires_at - session.commit() return self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, @@ -174,7 +172,7 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: datasource_providers = ( session.query(DatasourceProvider) .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) @@ -224,7 +222,6 @@ class DatasourceProviderService: provider=provider, ) real_credentials_list.append(real_credentials) - session.commit() return real_credentials_list @@ -234,7 +231,7 @@ class DatasourceProviderService: """ update datasource provider name """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: target_provider = ( session.query(DatasourceProvider) .filter_by( @@ -266,7 +263,6 @@ class DatasourceProviderService: raise ValueError("Authorization name is already exists") target_provider.name = name - session.commit() return def set_default_datasource_provider( @@ -275,7 +271,7 @@ class DatasourceProviderService: """ set default datasource provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider target_provider = ( session.query(DatasourceProvider) @@ -300,7 +296,6 @@ class DatasourceProviderService: # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} def setup_oauth_custom_client_params( @@ -315,7 +310,7 @@ class DatasourceProviderService: """ if client_params is None and enabled is None: return - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: tenant_oauth_client_params = ( session.query(DatasourceOauthTenantParamConfig) .filter_by( @@ -349,7 +344,6 @@ class DatasourceProviderService: if enabled is not None: tenant_oauth_client_params.enabled = enabled - session.commit() def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool: """ @@ -488,7 +482,7 @@ class DatasourceProviderService: """ update datasource oauth provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): target_provider = ( @@ -535,7 +529,6 @@ class DatasourceProviderService: target_provider.expires_at = expire_at target_provider.encrypted_credentials = credentials target_provider.avatar_url = avatar_url or target_provider.avatar_url - session.commit() def add_datasource_oauth_provider( self, @@ -550,7 +543,7 @@ class DatasourceProviderService: add datasource oauth provider """ credential_type = CredentialType.OAUTH2 - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}" with redis_client.lock(lock, timeout=60): db_provider_name = name @@ -604,7 +597,6 @@ class DatasourceProviderService: expires_at=expire_at, ) session.add(datasource_provider) - session.commit() def add_datasource_api_key_provider( self, @@ -623,7 +615,7 @@ class DatasourceProviderService: provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): db_provider_name = name or self.generate_next_datasource_provider_name( @@ -670,7 +662,6 @@ class DatasourceProviderService: encrypted_credentials=credentials, ) session.add(datasource_provider) - session.commit() def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]: """ @@ -926,7 +917,7 @@ class DatasourceProviderService: update datasource credentials. """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: datasource_provider = ( session.query(DatasourceProvider) .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) @@ -980,7 +971,6 @@ class DatasourceProviderService: encrypted_credentials[key] = value datasource_provider.encrypted_credentials = encrypted_credentials - session.commit() def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index bc4120e2af..70ecc158d6 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -40,7 +40,10 @@ class TestDatasourceProviderService: q returns itself for .filter_by(), .order_by(), .where() so any SQLAlchemy chaining pattern works without multiple brittle sub-mocks. """ - with patch("services.datasource_provider_service.Session") as mock_cls: + with ( + patch("services.datasource_provider_service.Session") as mock_cls, + patch("services.datasource_provider_service.sessionmaker") as mock_sm, + ): sess = MagicMock(spec=Session) q = MagicMock() @@ -63,6 +66,8 @@ class TestDatasourceProviderService: mock_cls.return_value.__enter__.return_value = sess mock_cls.return_value.no_autoflush.__enter__.return_value = sess + mock_sm.return_value.begin.return_value.__enter__.return_value = sess + mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) yield sess @@ -266,7 +271,6 @@ class TestDatasourceProviderService: patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), ): service.get_datasource_credentials("t1", "prov", "org/plug") - mock_db_session.commit.assert_called_once() def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): """API key credentials with expires_at=-1 skip refresh and return directly.""" @@ -333,7 +337,6 @@ class TestDatasourceProviderService: p.name = "same" mock_db_session.query().first.return_value = p service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") - mock_db_session.commit.assert_not_called() def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -352,7 +355,6 @@ class TestDatasourceProviderService: mock_db_session.query().count.return_value = 0 service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") assert p.name == "new_name" - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # set_default_datasource_provider (lines 277-303) @@ -370,7 +372,6 @@ class TestDatasourceProviderService: mock_db_session.query().first.return_value = target service.set_default_datasource_provider("t1", make_id(), "new-id") assert target.is_default is True - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # get_oauth_encrypter (lines 404-420) @@ -460,7 +461,6 @@ class TestDatasourceProviderService: with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): """Conflict on name results in auto-incremented name, not an error.""" @@ -512,7 +512,6 @@ class TestDatasourceProviderService: mock_db_session.query().count.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") - mock_db_session.commit.assert_called_once() def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -523,7 +522,6 @@ class TestDatasourceProviderService: service.reauthorize_datasource_oauth_provider( "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" ) - mock_db_session.commit.assert_called_once() def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -571,7 +569,6 @@ class TestDatasourceProviderService: ): service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"}) mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): mock_db_session.query().count.return_value = 0 @@ -747,7 +744,6 @@ class TestDatasourceProviderService: # encrypter must have been called with the new secret value self._enc.encrypt_token.assert_called() # commit must be called exactly once - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # remove_datasource_credentials (lines 980-997) @@ -758,7 +754,6 @@ class TestDatasourceProviderService: mock_db_session.scalar.return_value = p service.remove_datasource_credentials("t1", "id", "prov", "org/plug") mock_db_session.delete.assert_called_once_with(p) - mock_db_session.commit.assert_called_once() def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): """No error raised; no delete called when record doesn't exist (lines 994 branch).""" From f5ea61e93ed53755000694151744a8b534c163ab Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:44:13 -0500 Subject: [PATCH 03/14] refactor: migrate session.query to select API in document indexing sync task (#34813) --- api/tasks/document_indexing_sync_task.py | 16 +++++++++------- .../tasks/test_document_indexing_sync_task.py | 11 +++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index f99e90062f..90c80be3a1 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): tenant_id = None with session_factory.create_session() as session, session.begin(): - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) @@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") @@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." @@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: index_processor = IndexProcessorFactory(index_type).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green")) @@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.exception("Failed to clean vector index for document %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow")) return @@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: indexing_runner.run([document]) end_at = time.perf_counter() @@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): except Exception as e: logger.exception("document_indexing_sync_task failed for document_id: %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = str(e) diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index f49f4535af..41d3068a10 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset): with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory: session = MagicMock() session.scalars.return_value.all.return_value = [] - session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + session.scalar.side_effect = [mock_document, mock_dataset] begin_cm = MagicMock() begin_cm.__enter__.return_value = session @@ -242,14 +242,13 @@ class TestDataSourceInfoSerialization: # DB session mock — shared across all ``session_factory.create_session()`` calls session = MagicMock() session.scalars.return_value.all.return_value = [] - # .where() path: session 1 reads document + dataset, session 2 reads dataset - session.query.return_value.where.return_value.first.side_effect = [ + # All .first() calls are now session.scalar() — ordered by call sequence: + # session 1: document + dataset, session 2: dataset (clean), session 3: document (update), + # session 4: document (indexing) + session.scalar.side_effect = [ mock_document, mock_dataset, mock_dataset, - ] - # .filter_by() path: session 3 (update), session 4 (indexing) - session.query.return_value.filter_by.return_value.first.side_effect = [ mock_document, mock_document, ] From b5acc8e3925000ff3755474896a84a7082541e49 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:44:49 -0500 Subject: [PATCH 04/14] refactor: migrate session.query to select API in core tools (#34814) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/tools/tool_file_manager.py | 33 +++---------------- api/core/tools/workflow_as_tool/provider.py | 13 ++++---- .../core/tools/test_tool_file_manager.py | 26 ++++----------- .../tools/workflow_as_tool/test_provider.py | 8 ++--- 4 files changed, 23 insertions(+), 57 deletions(-) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 7ac29cf069..a59d167a0a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -11,6 +11,7 @@ from uuid import uuid4 import httpx from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -166,13 +167,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1)) if not tool_file: return None @@ -190,13 +185,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - message_file: MessageFile | None = ( - session.query(MessageFile) - .where( - MessageFile.id == id, - ) - .first() - ) + message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1)) # Check if message_file is not None if message_file is not None: @@ -210,13 +199,7 @@ class ToolFileManager: else: tool_file_id = None - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None @@ -234,13 +217,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None, None diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index f48b24be30..a01004448a 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field +from sqlalchemy import select from sqlalchemy.orm import Session from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController): :param app: the app :return: the tool """ - workflow: Workflow | None = ( - session.query(Workflow) + workflow: Workflow | None = session.scalar( + select(Workflow) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) - .first() + .limit(1) ) if not workflow: @@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools with Session(db.engine, expire_on_commit=False) as session, session.begin(): - db_provider: WorkflowToolProvider | None = ( - session.query(WorkflowToolProvider) + db_provider: WorkflowToolProvider | None = session.scalar( + select(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == self.provider_id, ) - .first() + .limit(1) ) if not db_provider: diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index 7fcebde3c5..2889cb9db1 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None: manager = ToolFileManager() tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = None - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [None, None] # Act with _patch_session_factory(session): @@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: manager = ToolFileManager() message_file = SimpleNamespace(url=None) session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, None] # Act with _patch_session_factory(session): @@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = tool_file - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, tool_file] # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None: size=12, ) session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index 2607861b59..4767480a5a 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity(): controller = _controller() session = Mock() workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow app = SimpleNamespace(id="app-1") db_provider = SimpleNamespace( id="provider-1", @@ -136,7 +136,7 @@ def test_from_db_builds_controller(): parameter_configurations=[], ) session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.side_effect = [app, user] fake_cm = MagicMock() fake_cm.__enter__.return_value = session @@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None session_cls.return_value.__enter__.return_value = session assert controller.get_tools("tenant-1") == [] @@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.return_value = None session_cls.return_value.__enter__.return_value = session with pytest.raises(ValueError, match="app not found"): From e3cc4b83c878747c56d2ee95a2022abb5be2d431 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:46:36 -0500 Subject: [PATCH 05/14] refactor: migrate session.query to select API in clean dataset task (#34815) --- api/tasks/clean_dataset_task.py | 30 +++++++++++-------- .../tasks/test_clean_dataset_task.py | 27 ++++------------- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 0d51a743ad..377d0e5cc7 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -112,7 +112,9 @@ def clean_dataset_task( segment_ids = [segment.id for segment in segments] for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + ).all() for image_file in image_files: if image_file is None: continue @@ -150,20 +152,22 @@ def clean_dataset_task( ) session.execute(binding_delete_stmt) - session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() - session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() - session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() + session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id)) + session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id)) + session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id)) # delete dataset metadata - session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() - session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id)) + session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id)) # delete pipeline and workflow if pipeline_id: - session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() - session.query(Workflow).where( - Workflow.tenant_id == tenant_id, - Workflow.app_id == pipeline_id, - Workflow.type == WorkflowType.RAG_PIPELINE, - ).delete() + session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id)) + session.execute( + delete(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ) + ) # delete files if documents: file_ids = [] @@ -174,7 +178,7 @@ def clean_dataset_task( if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] file_ids.append(file_id) - files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() for file in files: storage.delete(file.key) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 936a10d6c5..b4332334ab 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -60,12 +60,6 @@ def mock_db_session(): cm.__exit__.return_value = None mock_sf.create_session.return_value = cm - # Setup query chain - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 0 - # Setup scalars for select queries mock_session.scalars.return_value.all.return_value = [] @@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion: - Pipeline record is deleted - Related workflow record is deleted """ - # Arrange - mock_query = mock_db_session.session.query.return_value - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 1 - # Act clean_dataset_task( dataset_id=dataset_id, @@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion: pipeline_id=pipeline_id, ) - # Assert - verify delete was called for pipeline-related queries - # The actual count depends on total queries, but pipeline deletion should add 2 more - assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow + # Assert - verify execute was called for delete operations + # 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8 + assert mock_db_session.session.execute.call_count >= 8 def test_clean_dataset_task_without_pipeline_id( self, @@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion: Expected behavior: - Pipeline and workflow deletion queries are not executed """ - # Arrange - mock_query = mock_db_session.session.query.return_value - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 1 - # Act clean_dataset_task( dataset_id=dataset_id, @@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion: pipeline_id=None, ) - # Assert - verify delete was called only for base queries (5 times) - assert mock_query.delete.call_count == 5 + # Assert - verify execute was called for delete operations + # 1 attachment JOIN query + 5 base deletes = 6 + assert mock_db_session.session.execute.call_count == 6 # ============================================================================ From 5f53748d074f80802bd840dade44f2e8142eefdd Mon Sep 17 00:00:00 2001 From: dataCenter430 <161712630+dataCenter430@users.noreply.github.com> Date: Wed, 8 Apr 2026 22:48:40 -0700 Subject: [PATCH 06/14] refactor: convert ToolProviderType if/elif to match/case (#30001) (#34794) --- api/core/tools/entities/api_entities.py | 35 +++++++++++-------- api/services/tools/tools_transform_service.py | 34 ++++++++++-------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index d5d3d1b1d9..410ec72baf 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel): parameter.pop("input_schema", None) # ------------- optional_fields = self.optional_field("server_url", self.server_url) - if self.type == ToolProviderType.MCP: - optional_fields.update(self.optional_field("updated_at", self.updated_at)) - optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) - optional_fields.update( - self.optional_field( - "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + match self.type: + case ToolProviderType.MCP: + optional_fields.update(self.optional_field("updated_at", self.updated_at)) + optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) + optional_fields.update( + self.optional_field( + "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + ) ) - ) - optional_fields.update( - self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None) - ) - optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) - optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) - optional_fields.update(self.optional_field("original_headers", self.original_headers)) - elif self.type == ToolProviderType.WORKFLOW: - optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id)) + optional_fields.update( + self.optional_field( + "authentication", self.authentication.model_dump() if self.authentication else None + ) + ) + optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) + optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) + optional_fields.update(self.optional_field("original_headers", self.original_headers)) + case ToolProviderType.WORKFLOW: + optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id)) + case _: + pass return { "id": self.id, "author": self.author, diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b24f001133..4fd2ea1628 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -48,21 +48,25 @@ class ToolTransformService: URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider" ) - if provider_type == ToolProviderType.BUILT_IN: - return str(url_prefix / "builtin" / provider_name / "icon") - elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}: - try: - if isinstance(icon, str): - parsed = emoji_icon_adapter.validate_json(icon) - return {"background": parsed["background"], "content": parsed["content"]} - return {"background": icon["background"], "content": icon["content"]} - except (ValueError, ValidationError, KeyError): - return {"background": "#252525", "content": "\ud83d\ude01"} - elif provider_type == ToolProviderType.MCP: - if isinstance(icon, Mapping): - return {"background": icon.get("background", ""), "content": icon.get("content", "")} - return icon - return "" + match provider_type: + case ToolProviderType.BUILT_IN: + return str(url_prefix / "builtin" / provider_name / "icon") + case ToolProviderType.API | ToolProviderType.WORKFLOW: + try: + if isinstance(icon, str): + parsed = emoji_icon_adapter.validate_json(icon) + return {"background": parsed["background"], "content": parsed["content"]} + return {"background": icon["background"], "content": icon["content"]} + except (ValueError, ValidationError, KeyError): + return {"background": "#252525", "content": "\ud83d\ude01"} + case ToolProviderType.MCP: + if isinstance(icon, Mapping): + return {"background": icon.get("background", ""), "content": icon.get("content", "")} + return icon + case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL: + return "" + case _: + return "" @staticmethod def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): From d360929af1687439e0483e6f103d56a1585c86aa Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:49:03 -0500 Subject: [PATCH 07/14] refactor(api): use sessionmaker in pgvecto_rs VDB service (#34818) --- .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 20 ++++----- .../vdb/pgvecto_rs/test_pgvecto_rs.py | 41 +++++++++++++++---- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 90d9173409..387e918c76 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Float, create_engine, insert, select, text from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM @@ -55,9 +55,8 @@ class PGVectoRS(BaseVector): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self._client = create_engine(self._url) - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) - session.commit() self._fields: list[str] = [] class _Table(CollectionORM): @@ -88,7 +87,7 @@ class PGVectoRS(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id UUID PRIMARY KEY, @@ -111,12 +110,11 @@ class PGVectoRS(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): pks = [] - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: for document, embedding in zip(documents, embeddings): pk = uuid4() session.execute( @@ -128,7 +126,6 @@ class PGVectoRS(BaseVector): ), ) pks.append(pk) - session.commit() return pks @@ -145,10 +142,9 @@ class PGVectoRS(BaseVector): def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete_by_ids(self, ids: list[str]): with Session(self._client) as session: @@ -159,15 +155,13 @@ class PGVectoRS(BaseVector): if result: ids = [item[0] for item in result] if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete(self): - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self._client) as session: diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py index 1aec81b8ac..5b9ec8002a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -53,6 +53,31 @@ def _session_factory(calls, execute_results=None): return _session +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _sessionmaker_factory(calls, execute_results=None): + def _sessionmaker(*args, **kwargs): + session = _FakeSessionContext(calls=calls, execute_results=execute_results) + return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + + return _sessionmaker + + +def _patch_both(monkeypatch, module, calls, execute_results=None): + """Patch both Session and sessionmaker on the module with the same call tracker.""" + monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results)) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results)) + + @pytest.fixture def pgvecto_module(monkeypatch): for name, module in _build_fake_pgvecto_modules().items(): @@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) vector.create_collection = MagicMock() @@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) lock = MagicMock() lock.__enter__.return_value = None @@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + _patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results)) class _InsertBuilder: def __init__(self, table): @@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): "Session", _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] monkeypatch.setattr( @@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): ], ), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) vector.delete_by_ids(["doc-1"]) assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) runtime_calls.clear() - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()]) vector.delete() assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) @@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): module, _ = pgvecto_module init_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) runtime_calls = [] @@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), ] - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[rows]) docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) assert len(docs) == 1 From ee789db443642565cfa2e3c38f58680adef30bc3 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:49:59 -0500 Subject: [PATCH 08/14] refactor: migrate session.query to select API in plugin services (#34817) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../plugin/plugin_auto_upgrade_service.py | 19 ++++++------ .../plugin/plugin_permission_service.py | 9 ++++-- .../test_plugin_auto_upgrade_service.py | 29 ++++++++++--------- .../plugin/test_plugin_permission_service.py | 10 +++---- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index adbed87c3c..a58bede8db 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,3 +1,4 @@ +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from extensions.ext_database import db @@ -8,10 +9,10 @@ class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: with sessionmaker(bind=db.engine).begin() as session: - return ( - session.query(TenantPluginAutoUpgradeStrategy) + return session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) @staticmethod @@ -24,10 +25,10 @@ class PluginAutoUpgradeService: include_plugins: list[str], ) -> bool: with sessionmaker(bind=db.engine).begin() as session: - exist_strategy = ( - session.query(TenantPluginAutoUpgradeStrategy) + exist_strategy = session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) if not exist_strategy: strategy = TenantPluginAutoUpgradeStrategy( @@ -51,10 +52,10 @@ class PluginAutoUpgradeService: @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: with sessionmaker(bind=db.engine).begin() as session: - exist_strategy = ( - session.query(TenantPluginAutoUpgradeStrategy) + exist_strategy = session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) if not exist_strategy: # create for this tenant diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 55276d6f99..0d2a70acbd 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,3 +1,4 @@ +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from extensions.ext_database import db @@ -8,7 +9,9 @@ class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: with sessionmaker(bind=db.engine).begin() as session: - return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() + return session.scalar( + select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) + ) @staticmethod def change_permission( @@ -17,8 +20,8 @@ class PluginPermissionService: debug_permission: TenantPluginPermission.DebugPermission, ): with sessionmaker(bind=db.engine).begin() as session: - permission = ( - session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() + permission = session.scalar( + select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) if not permission: permission = TenantPluginPermission( diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py index 45156958b6..bc2f1c6ecc 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -20,7 +20,7 @@ class TestGetStrategy: def test_returns_strategy_when_found(self): p1, p2, session = _patched_session() strategy = MagicMock() - session.query.return_value.where.return_value.first.return_value = strategy + session.scalar.return_value = strategy with p1, p2: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -31,7 +31,7 @@ class TestGetStrategy: def test_returns_none_when_not_found(self): p1, p2, session = _patched_session() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with p1, p2: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -44,9 +44,9 @@ class TestGetStrategy: class TestChangeStrategy: def test_creates_new_strategy(self): p1, p2, session = _patched_session() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.return_value = MagicMock() from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -65,7 +65,7 @@ class TestChangeStrategy: def test_updates_existing_strategy(self): p1, p2, session = _patched_session() existing = MagicMock() - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing with p1, p2: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -90,11 +90,12 @@ class TestChangeStrategy: class TestExcludePlugin: def test_creates_default_strategy_when_none_exists(self): p1, p2, session = _patched_session() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with ( p1, p2, + patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls, patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs, ): @@ -113,9 +114,9 @@ class TestExcludePlugin: existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p-existing"] - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -131,9 +132,9 @@ class TestExcludePlugin: existing = MagicMock() existing.upgrade_mode = "partial" existing.include_plugins = ["p1", "p2"] - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -148,9 +149,9 @@ class TestExcludePlugin: p1, p2, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "all" - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -167,9 +168,9 @@ class TestExcludePlugin: existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p1"] - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 40f4c6a8d2..20f132c015 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -20,7 +20,7 @@ class TestGetPermission: def test_returns_permission_when_found(self): p1, p2, session = _patched_session() permission = MagicMock() - session.query.return_value.where.return_value.first.return_value = permission + session.scalar.return_value = permission with p1, p2: from services.plugin.plugin_permission_service import PluginPermissionService @@ -31,7 +31,7 @@ class TestGetPermission: def test_returns_none_when_not_found(self): p1, p2, session = _patched_session() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with p1, p2: from services.plugin.plugin_permission_service import PluginPermissionService @@ -44,9 +44,9 @@ class TestGetPermission: class TestChangePermission: def test_creates_new_permission_when_not_exists(self): p1, p2, session = _patched_session() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls: + with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: perm_cls.return_value = MagicMock() from services.plugin.plugin_permission_service import PluginPermissionService @@ -59,7 +59,7 @@ class TestChangePermission: def test_updates_existing_permission(self): p1, p2, session = _patched_session() existing = MagicMock() - session.query.return_value.where.return_value.first.return_value = existing + session.scalar.return_value = existing with p1, p2: from services.plugin.plugin_permission_service import PluginPermissionService From 9a51c2f56ab189e69d91c606a40792fd3f620955 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:50:59 -0500 Subject: [PATCH 09/14] refactor: migrate session.query to select API in deal dataset vector index task (#34819) --- api/tasks/deal_dataset_vector_index_task.py | 54 ++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 0047e04a17..36605359dc 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") @@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": @@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: From 66e588c8caaa7e89fbac3baa197d754d9d628db9 Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:58:38 -0500 Subject: [PATCH 10/14] refactor(api): use sessionmaker in builtin tools manage service (#34812) --- .../tools/builtin_tools_manage_service.py | 22 +++----- .../test_builtin_tools_manage_service.py | 54 ++++++++++--------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index d529d2f065..3daaf9a263 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any from sqlalchemy import exists, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -46,13 +46,12 @@ class BuiltinToolManageService: delete custom oauth client params """ tool_provider = ToolProviderID(provider) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(ToolOAuthTenantClient).filter_by( tenant_id=tenant_id, provider=tool_provider.provider_name, plugin_id=tool_provider.plugin_id, ).delete() - session.commit() return {"result": "success"} @staticmethod @@ -150,7 +149,7 @@ class BuiltinToolManageService: """ update builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get if the provider exists db_provider = ( session.query(BuiltinToolProvider) @@ -203,9 +202,7 @@ class BuiltinToolManageService: db_provider.name = name - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -222,7 +219,7 @@ class BuiltinToolManageService: """ add builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): @@ -281,9 +278,7 @@ class BuiltinToolManageService: ) session.add(db_provider) - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -379,7 +374,7 @@ class BuiltinToolManageService: """ delete tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: db_provider = ( session.query(BuiltinToolProvider) .where( @@ -393,7 +388,6 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider}") session.delete(db_provider) - session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -409,7 +403,7 @@ class BuiltinToolManageService: """ set default provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() if target_provider is None: @@ -422,7 +416,6 @@ class BuiltinToolManageService: # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} @@ -654,7 +647,7 @@ class BuiltinToolManageService: if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: custom_client_params = ( session.query(ToolOAuthTenantClient) .filter_by( @@ -690,7 +683,6 @@ class BuiltinToolManageService: if enable_oauth_custom_client is not None: custom_client_params.enabled = enable_oauth_custom_client - session.commit() return {"result": "success"} @staticmethod diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 175900071b..e80c306854 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -15,17 +15,24 @@ def _mock_session(mock_session_cls): return session +def _mock_sessionmaker(mock_sm_cls): + """Helper: set up a sessionmaker().begin() context manager mock and return the inner session.""" + session = MagicMock() + mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + return session + + class TestDeleteCustomOauthClientParams: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_deletes_and_returns_success(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_deletes_and_returns_success(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") assert result == {"result": "success"} session.query.return_value.filter_by.return_value.delete.assert_called_once() - session.commit.assert_called_once() class TestListBuiltinToolProviderTools: @@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled: class TestDeleteBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): - session = _mock_session(mock_session_cls) + def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.where.return_value.first.return_value = None with pytest.raises(ValueError, match="you have not added provider"): @@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): - session = _mock_session(mock_session_cls) + def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock() session.query.return_value.where.return_value.first.return_value = db_provider mock_cache = MagicMock() @@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider: assert result == {"result": "success"} session.delete.assert_called_once_with(db_provider) - session.commit.assert_called_once() mock_cache.delete.assert_called_once() class TestSetDefaultProvider: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_not_found(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_raises_when_not_found(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.filter_by.return_value.first.return_value = None with pytest.raises(ValueError, match="provider not found"): BuiltinToolManageService.set_default_provider("t", "u", "p", "id") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) target = MagicMock() session.query.return_value.filter_by.return_value.first.return_value = target @@ -187,14 +193,13 @@ class TestSetDefaultProvider: assert result == {"result": "success"} assert target.is_default is True - session.commit.assert_called_once() class TestUpdateBuiltinToolProvider: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.where.return_value.first.return_value = None with pytest.raises(ValueError, match="you have not added provider"): @@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.CredentialType") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): - session = _mock_session(mock_session_cls) + def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock(credential_type="api_key", credentials="{}") session.query.return_value.where.return_value.first.return_value = db_provider @@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider: result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) assert result == {"result": "success"} - session.commit.assert_called_once() mock_cache.delete.assert_called_once() From 4c05316a7b5bda820e4d08abf157ef1e49cb8896 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:04:18 -0500 Subject: [PATCH 11/14] refactor(api): deduplicate DSL shared entities into dsl_entities.py (#34762) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/app.py | 3 ++- api/controllers/console/app/app_import.py | 3 ++- .../rag_pipeline/rag_pipeline_import.py | 2 +- api/controllers/inner_api/app/dsl.py | 3 ++- api/services/app_dsl_service.py | 20 ++---------------- api/services/entities/dsl_entities.py | 21 +++++++++++++++++++ .../rag_pipeline/rag_pipeline_dsl_service.py | 8 ++----- 7 files changed, 32 insertions(+), 28 deletions(-) create mode 100644 api/services/entities/dsl_entities.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c4b9bf6540..2018f60215 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -34,9 +34,10 @@ from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType -from services.app_dsl_service import AppDslService, ImportMode +from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.entities.dsl_entities import ImportMode from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 06192936f1..12d6951a48 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -17,8 +17,9 @@ from fields.app_fields import ( ) from libs.login import current_account_with_tenant, login_required from models.model import App -from services.app_dsl_service import AppDslService, ImportStatus +from services.app_dsl_service import AppDslService from services.enterprise.enterprise_service import EnterpriseService +from services.entities.dsl_entities import ImportStatus from services.feature_service import FeatureService from .. import console_ns diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 76a8c136e4..aa27458176 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import ( ) from libs.login import current_account_with_tenant, login_required from models.dataset import Pipeline -from services.app_dsl_service import ImportStatus +from services.entities.dsl_entities import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index b1986b2557..6c15f9aa8b 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only from extensions.ext_database import db from models import Account, App from models.account import AccountStatus -from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_dsl_service import AppDslService +from services.entities.dsl_entities import ImportMode, ImportStatus class InnerAppDSLImportPayload(BaseModel): diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index dd73e10374..c6c8a15109 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,7 +3,6 @@ import hashlib import logging import uuid from collections.abc import Mapping -from enum import StrEnum from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -19,7 +18,7 @@ from graphon.nodes.question_classifier.entities import QuestionClassifierNodeDat from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version -from pydantic import BaseModel, Field +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow +from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -53,18 +53,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB CURRENT_DSL_VERSION = "0.6.0" -class ImportMode(StrEnum): - YAML_CONTENT = "yaml-content" - YAML_URL = "yaml-url" - - -class ImportStatus(StrEnum): - COMPLETED = "completed" - COMPLETED_WITH_WARNINGS = "completed-with-warnings" - PENDING = "pending" - FAILED = "failed" - - class Import(BaseModel): id: str status: ImportStatus @@ -75,10 +63,6 @@ class Import(BaseModel): error: str = "" -class CheckDependenciesResult(BaseModel): - leaked_dependencies: list[PluginDependency] = Field(default_factory=list) - - def _check_version_compatibility(imported_version: str) -> ImportStatus: """Determine import status based on version comparison""" try: diff --git a/api/services/entities/dsl_entities.py b/api/services/entities/dsl_entities.py new file mode 100644 index 0000000000..05baa51fbd --- /dev/null +++ b/api/services/entities/dsl_entities.py @@ -0,0 +1,21 @@ +from enum import StrEnum + +from pydantic import BaseModel, Field + +from core.plugin.entities.plugin import PluginDependency + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index e42c020925..c24bf3d649 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -20,7 +20,7 @@ from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeDat from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from graphon.nodes.tool.entities import ToolNodeData from packaging import version -from pydantic import BaseModel, Field +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -37,7 +37,7 @@ from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType -from services.app_dsl_service import ImportMode, ImportStatus +from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, KnowledgeConfiguration, @@ -64,10 +64,6 @@ class RagPipelineImportInfo(BaseModel): dataset_id: str | None = None -class CheckDependenciesResult(BaseModel): - leaked_dependencies: list[PluginDependency] = Field(default_factory=list) - - def _check_version_compatibility(imported_version: str) -> ImportStatus: """Determine import status based on version comparison""" try: From 8225f9856586e4abcfae46431670c91afa205fa9 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Thu, 9 Apr 2026 14:09:27 +0800 Subject: [PATCH 12/14] fix(web): use nuqs for log conversation url state (#34820) --- .../app/log/__tests__/list.spec.tsx | 244 +++++++++--------- web/app/components/app/log/list.tsx | 20 +- 2 files changed, 123 insertions(+), 141 deletions(-) diff --git a/web/app/components/app/log/__tests__/list.spec.tsx b/web/app/components/app/log/__tests__/list.spec.tsx index a5d801f13f..25512ed689 100644 --- a/web/app/components/app/log/__tests__/list.spec.tsx +++ b/web/app/components/app/log/__tests__/list.spec.tsx @@ -1,14 +1,13 @@ /* eslint-disable ts/no-explicit-any */ import type { ReactNode } from 'react' -import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, screen, waitFor } from '@testing-library/react' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' import ConversationList from '../list' const mockFetchChatMessages = vi.fn() const mockUpdateLogMessageFeedbacks = vi.fn() const mockUpdateLogMessageAnnotations = vi.fn() -const mockPush = vi.fn() -const mockReplace = vi.fn() const mockOnRefresh = vi.fn() const mockSetCurrentLogItem = vi.fn() const mockSetShowPromptLogModal = vi.fn() @@ -17,7 +16,6 @@ const mockSetShowMessageLogModal = vi.fn() const mockCompletionRefetch = vi.fn() const mockDelAnnotation = vi.fn() -let mockSearchParams = new URLSearchParams() let mockChatConversationDetail: Record | undefined let mockCompletionConversationDetail: Record | undefined let mockShowMessageLogModal = false @@ -53,18 +51,6 @@ vi.mock('@/hooks/use-breakpoints', () => ({ }, })) -vi.mock('@/next/navigation', () => ({ - useRouter: () => ({ - push: mockPush, - replace: mockReplace, - }), - usePathname: () => '/apps/app-1/logs', - useSearchParams: () => ({ - get: (key: string) => mockSearchParams.get(key), - toString: () => mockSearchParams.toString(), - }), -})) - vi.mock('@/service/use-log', () => ({ useChatConversationDetail: () => ({ data: mockChatConversationDetail, @@ -256,10 +242,28 @@ const createChatMessage = (id: string, overrides: Record = {}) ...overrides, }) +const renderConversationList = ({ + appDetail = { id: 'app-1', mode: AppModeEnum.CHAT } as any, + logs = createLogs() as any, + searchParams = '?page=2', +}: { + appDetail?: any + logs?: any + searchParams?: string +} = {}) => { + return renderWithNuqs( + , + { searchParams }, + ) +} + describe('ConversationList', () => { beforeEach(() => { vi.clearAllMocks() - mockSearchParams = new URLSearchParams('page=2') mockChatConversationDetail = undefined mockCompletionConversationDetail = undefined mockShowMessageLogModal = false @@ -273,34 +277,29 @@ describe('ConversationList', () => { }) }) - it('should render chat rows and push the conversation id into the url when a row is clicked', () => { - render( - , - ) + it('should render chat rows and push the conversation id into the url when a row is clicked', async () => { + const { onUrlUpdate } = renderConversationList() expect(screen.getByText('hello world')).toBeInTheDocument() expect(screen.getAllByText('formatted-1710000000')).toHaveLength(2) fireEvent.click(screen.getByText('hello world')) - expect(mockPush).toHaveBeenCalledWith('/apps/app-1/logs?page=2&conversation_id=conversation-1', { scroll: false }) - expect(screen.getByTestId('drawer')).toBeInTheDocument() + await waitFor(() => { + expect(onUrlUpdate).toHaveBeenCalled() + expect(screen.getByTestId('drawer')).toBeInTheDocument() + }) + + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(update.searchParams.get('page')).toBe('2') + expect(update.searchParams.get('conversation_id')).toBe('conversation-1') + expect(update.options.history).toBe('push') }) - it('should close the drawer, refresh, and clear modal flags', () => { - mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1') - - render( - , - ) + it('should close the drawer, refresh, and clear modal flags', async () => { + const { onUrlUpdate } = renderConversationList({ + searchParams: '?page=2&conversation_id=conversation-1', + }) fireEvent.click(screen.getByText('close-drawer')) @@ -308,11 +307,18 @@ describe('ConversationList', () => { expect(mockSetShowPromptLogModal).toHaveBeenCalledWith(false) expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false) expect(mockSetShowMessageLogModal).toHaveBeenCalledWith(false) - expect(mockReplace).toHaveBeenCalledWith('/apps/app-1/logs?page=2', { scroll: false }) + + await waitFor(() => { + expect(onUrlUpdate).toHaveBeenCalled() + }) + + const update = onUrlUpdate.mock.calls.at(-1)![0] + expect(update.searchParams.get('page')).toBe('2') + expect(update.searchParams.has('conversation_id')).toBe(false) + expect(update.options.history).toBe('replace') }) it('should render chat conversation details and submit feedback from the chat panel', async () => { - mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1') mockChatConversationDetail = { id: 'conversation-1', created_at: 1710000000, @@ -355,13 +361,9 @@ describe('ConversationList', () => { mockShowMessageLogModal = true mockCurrentLogItem = { id: 'log-1' } - render( - , - ) + renderConversationList({ + searchParams: '?page=2&conversation_id=conversation-1', + }) await waitFor(() => { expect(mockFetchChatMessages).toHaveBeenCalledWith({ @@ -396,7 +398,6 @@ describe('ConversationList', () => { }) it('should render completion details and refetch after feedback updates', async () => { - mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1') mockCompletionConversationDetail = { id: 'conversation-1', created_at: 1710000000, @@ -423,13 +424,11 @@ describe('ConversationList', () => { mockShowPromptLogModal = true mockCurrentLogItem = { id: 'log-2' } - render( - , - ) + renderConversationList({ + appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any, + logs: createCompletionLogs() as any, + searchParams: '?page=2&conversation_id=conversation-1', + }) await waitFor(() => { expect(screen.getByTestId('text-generation')).toBeInTheDocument() @@ -454,64 +453,61 @@ describe('ConversationList', () => { }) it('should render chatflow status cells and feedback counters for advanced chat logs', () => { - render( - , - ) + renderConversationList({ + appDetail: { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } as any, + logs: { + data: [ + { + id: 'conversation-pending', + name: 'Pending row', + from_account_name: 'user-a', + read_at: 1710000001, + message_count: 3, + status_count: { paused: 1, success: 0, failed: 0, partial_success: 0 }, + user_feedback_stats: { like: 2, dislike: 0 }, + admin_feedback_stats: { like: 0, dislike: 1 }, + updated_at: 1710000000, + created_at: 1710000000, + }, + { + id: 'conversation-success', + name: 'Success row', + from_account_name: 'user-b', + read_at: 1710000001, + message_count: 4, + status_count: { paused: 0, success: 4, failed: 0, partial_success: 0 }, + user_feedback_stats: { like: 0, dislike: 0 }, + admin_feedback_stats: { like: 0, dislike: 0 }, + updated_at: 1710000000, + created_at: 1710000000, + }, + { + id: 'conversation-partial', + name: 'Partial row', + from_account_name: 'user-c', + read_at: 1710000001, + message_count: 5, + status_count: { paused: 0, success: 3, failed: 0, partial_success: 1 }, + user_feedback_stats: { like: 0, dislike: 0 }, + admin_feedback_stats: { like: 0, dislike: 0 }, + updated_at: 1710000000, + created_at: 1710000000, + }, + { + id: 'conversation-failure', + name: 'Failure row', + from_account_name: 'user-d', + read_at: 1710000001, + message_count: 1, + status_count: { paused: 0, success: 0, failed: 2, partial_success: 0 }, + user_feedback_stats: { like: 0, dislike: 0 }, + admin_feedback_stats: { like: 0, dislike: 0 }, + updated_at: 1710000000, + created_at: 1710000000, + }, + ], + } as any, + }) expect(screen.getByText('Pending')).toBeInTheDocument() expect(screen.getByText('Success')).toBeInTheDocument() @@ -522,7 +518,6 @@ describe('ConversationList', () => { }) it('should support annotation changes, modal closing, and paginated scroll loading in the detail drawer', async () => { - mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1') mockChatConversationDetail = { id: 'conversation-1', created_at: 1710000000, @@ -568,13 +563,9 @@ describe('ConversationList', () => { has_more: false, }) - render( - , - ) + renderConversationList({ + searchParams: '?page=2&conversation_id=conversation-1', + }) await waitFor(() => { expect(screen.getByTestId('chat-panel')).toBeInTheDocument() @@ -609,7 +600,6 @@ describe('ConversationList', () => { }) it('should close the prompt log modal from completion detail drawers', async () => { - mockSearchParams = new URLSearchParams('page=2&conversation_id=conversation-1') mockCompletionConversationDetail = { id: 'conversation-1', created_at: 1710000000, @@ -636,13 +626,11 @@ describe('ConversationList', () => { mockShowPromptLogModal = true mockCurrentLogItem = { id: 'log-2' } - render( - , - ) + renderConversationList({ + appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any, + logs: createCompletionLogs() as any, + searchParams: '?page=2&conversation_id=conversation-1', + }) expect(await screen.findByTestId('prompt-log-modal')).toBeInTheDocument() diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 79323d34ab..01621e0d2a 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -13,6 +13,7 @@ import dayjs from 'dayjs' import timezone from 'dayjs/plugin/timezone' import utc from 'dayjs/plugin/utc' import { noop } from 'es-toolkit/function' +import { parseAsString, useQueryState } from 'nuqs' import * as React from 'react' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -33,7 +34,6 @@ import { WorkflowContextProvider } from '@/app/components/workflow/context' import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' -import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' import { AppSourceType } from '@/service/share' import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log' @@ -46,7 +46,6 @@ import { applyAnnotationEdited, applyAnnotationRemoved, buildChatThreadState, - buildConversationUrl, getCompletionMessageFiles, getConversationRowValues, getDetailVarList, @@ -674,10 +673,7 @@ const ChatConversationDetailComp: FC<{ appId?: string, conversationId?: string } const ConversationList: FC = ({ logs, appDetail, onRefresh }) => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const router = useRouter() - const pathname = usePathname() - const searchParams = useSearchParams() - const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined + const [conversationIdInUrl, setConversationIdInUrl] = useQueryState('conversation_id', parseAsString) const media = useBreakpoints() const isMobile = media === MediaType.mobile @@ -697,8 +693,6 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id - const buildUrlWithConversation = useCallback((conversationId?: string) => buildConversationUrl(pathname, searchParams.toString(), conversationId), [pathname, searchParams]) - const handleRowClick = useCallback((log: ConversationListItem) => { if (conversationIdInUrl === log.id) { if (!showDrawer) @@ -717,8 +711,8 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) if (currentConversation?.id !== log.id) setCurrentConversation(undefined) - router.push(buildUrlWithConversation(log.id), { scroll: false }) - }, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer]) + void setConversationIdInUrl(log.id, { history: 'push' }) + }, [conversationIdInUrl, currentConversation, setConversationIdInUrl, showDrawer]) const currentConversationId = currentConversation?.id @@ -755,7 +749,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation) pendingConversationCacheRef.current = undefined - }, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer]) + }, [conversationIdInUrl, currentConversation, currentConversationId, logs?.data, showDrawer]) const onCloseDrawer = useCallback(() => { onRefresh() @@ -769,8 +763,8 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) closingConversationIdRef.current = conversationIdInUrl ?? null if (conversationIdInUrl) - router.replace(buildUrlWithConversation(), { scroll: false }) - }, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal]) + void setConversationIdInUrl(null, { history: 'replace' }) + }, [conversationIdInUrl, onRefresh, setConversationIdInUrl, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal]) // Annotated data needs to be highlighted const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => { From d5ababfed0957909c312eb9fce2453039536ca7f Mon Sep 17 00:00:00 2001 From: Jonathan Chang <55106972+jonathanchang31@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:14:48 -0500 Subject: [PATCH 13/14] refactor(api): deduplicate json serialization in AppModelConfig.from_model_config_dict (#34795) --- api/models/model.py | 64 ++++++++++++++------------------------------- 1 file changed, 20 insertions(+), 44 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index 12865c4d22..ece3ff8b87 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -813,56 +813,32 @@ class AppModelConfig(TypeBase): "file_upload": self.file_upload_dict, } + @staticmethod + def _dump_optional(value: Any) -> str | None: + return json.dumps(value) if value else None + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") - self.suggested_questions = ( - json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None - ) - self.suggested_questions_after_answer = ( - json.dumps(model_config.get("suggested_questions_after_answer")) - if model_config.get("suggested_questions_after_answer") - else None - ) - self.speech_to_text = ( - json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None - ) - self.text_to_speech = ( - json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None - ) - self.more_like_this = ( - json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None - ) - self.sensitive_word_avoidance = ( - json.dumps(model_config.get("sensitive_word_avoidance")) - if model_config.get("sensitive_word_avoidance") - else None - ) - self.external_data_tools = ( - json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None - ) - self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None - self.user_input_form = ( - json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None + self.suggested_questions = self._dump_optional(model_config.get("suggested_questions")) + self.suggested_questions_after_answer = self._dump_optional( + model_config.get("suggested_questions_after_answer") ) + self.speech_to_text = self._dump_optional(model_config.get("speech_to_text")) + self.text_to_speech = self._dump_optional(model_config.get("text_to_speech")) + self.more_like_this = self._dump_optional(model_config.get("more_like_this")) + self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance")) + self.external_data_tools = self._dump_optional(model_config.get("external_data_tools")) + self.model = self._dump_optional(model_config.get("model")) + self.user_input_form = self._dump_optional(model_config.get("user_input_form")) self.dataset_query_variable = model_config.get("dataset_query_variable") self.pre_prompt = model_config.get("pre_prompt") - self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None - self.retriever_resource = ( - json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None - ) + self.agent_mode = self._dump_optional(model_config.get("agent_mode")) + self.retriever_resource = self._dump_optional(model_config.get("retriever_resource")) self.prompt_type = PromptType(model_config.get("prompt_type", "simple")) - self.chat_prompt_config = ( - json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None - ) - self.completion_prompt_config = ( - json.dumps(model_config.get("completion_prompt_config")) - if model_config.get("completion_prompt_config") - else None - ) - self.dataset_configs = ( - json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None - ) - self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None + self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config")) + self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config")) + self.dataset_configs = self._dump_optional(model_config.get("dataset_configs")) + self.file_upload = self._dump_optional(model_config.get("file_upload")) return self From ec56f4e8399a6d9ef12309d46c9ffbce2fee9b0e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 9 Apr 2026 14:44:28 +0800 Subject: [PATCH 14/14] fix(docker): restore S3_ADDRESS_STYLE env examples (#34826) --- api/.env.example | 1 + docker/.env.example | 1 + docker/docker-compose.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/api/.env.example b/api/.env.example index 2c1a755059..a04a18944a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -109,6 +109,7 @@ S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +S3_ADDRESS_STYLE=auto # Workflow run and Conversation archive storage (S3-compatible) ARCHIVE_STORAGE_ENABLED=false diff --git a/docker/.env.example b/docker/.env.example index f6da6c568d..4426a882f1 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -469,6 +469,7 @@ S3_REGION=us-east-1 S3_BUCKET_NAME=difyai S3_ACCESS_KEY= S3_SECRET_KEY= +S3_ADDRESS_STYLE=auto # Whether to use AWS managed IAM roles for authenticating with the S3 service. # If set to false, the access key and secret key must be provided. S3_USE_AWS_MANAGED_IAM=false diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index dbadc58f89..1fc1cfdf9e 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -131,6 +131,7 @@ x-shared-env: &shared-api-worker-env S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} S3_ACCESS_KEY: ${S3_ACCESS_KEY:-} S3_SECRET_KEY: ${S3_SECRET_KEY:-} + S3_ADDRESS_STYLE: ${S3_ADDRESS_STYLE:-auto} S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false} ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false} ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}