From 8ad131bb3badc8f66b687b4d629ffaf14fb505b4 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:15:59 -0500 Subject: [PATCH 01/14] refactor: migrate session.query to select API in file service (#34852) --- api/services/file_service.py | 14 ++++----- .../unit_tests/services/test_file_service.py | 30 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/api/services/file_service.py b/api/services/file_service.py index 50a326d813..7443ca3271 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -132,8 +132,8 @@ class FileService: return file_size <= file_size_limit def get_file_base64(self, file_id: str) -> str: - upload_file = ( - self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = self._session_maker(expire_on_commit=False).scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) ) if not upload_file: raise NotFound("File not found") @@ -178,7 +178,7 @@ class FileService: Return a short text preview extracted from a document file. """ with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") @@ -200,7 +200,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -220,7 +220,7 @@ class FileService: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -231,7 +231,7 @@ class FileService: def get_public_image_preview(self, file_id: str): with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -247,7 +247,7 @@ class FileService: def get_file_content(self, file_id: str) -> str: with self._session_maker(expire_on_commit=False) as session: - upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index b7259c3e82..8e1b22886b 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -165,7 +165,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "test_key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load_once.return_value = b"test content" @@ -178,7 +178,7 @@ class TestFileService: mock_storage.load_once.assert_called_once_with("test_key") def test_get_file_base64_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_base64("non_existent") @@ -215,7 +215,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "pdf" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: mock_extract.return_value = "Extracted text content" @@ -227,7 +227,7 @@ class TestFileService: assert result == "Extracted text content" def test_get_file_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_preview("non_existent") @@ -235,7 +235,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "exe" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with pytest.raises(UnsupportedFileTypeError): file_service.get_file_preview("file_id") @@ -246,7 +246,7 @@ class TestFileService: upload_file.extension = "jpg" upload_file.mime_type = "image/jpeg" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with ( patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, @@ -269,7 +269,7 @@ class TestFileService: file_service.get_image_preview("file_id", "ts", "nonce", "sign") def test_get_image_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): @@ -279,7 +279,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(UnsupportedFileTypeError): @@ -289,7 +289,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with ( patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, @@ -309,7 +309,7 @@ class TestFileService: file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): @@ -321,7 +321,7 @@ class TestFileService: upload_file.extension = "png" upload_file.mime_type = "image/png" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load.return_value = b"image content" @@ -330,7 +330,7 @@ class TestFileService: assert mime == "image/png" def test_get_public_image_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_public_image_preview("file_id") @@ -338,7 +338,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with pytest.raises(UnsupportedFileTypeError): file_service.get_public_image_preview("file_id") @@ -346,7 +346,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load.return_value = b"hello world" @@ -354,7 +354,7 @@ class TestFileService: assert result == "hello world" def test_get_file_content_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_content("file_id") From e143dbce50bd4458c9fa24c093cdbf58533659f6 Mon Sep 17 00:00:00 2001 From: aliworksx08 <57456290+aliworksx08@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:16:33 -0500 Subject: [PATCH 02/14] refactor: migrate session.query to select API in webhook service (#34849) --- api/services/trigger/webhook_service.py | 28 +++++++++---------- .../services/test_webhook_service.py | 14 +++++----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 7b69ccfce7..bb767a6759 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -104,32 +104,32 @@ class WebhookService: """ with Session(db.engine) as session: # Get webhook trigger - webhook_trigger = ( - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first() + webhook_trigger = session.scalar( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1) ) if not webhook_trigger: raise ValueError(f"Webhook not found: {webhook_id}") if is_debug: - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version == Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) else: # Check if the corresponding AppTrigger exists - app_trigger = ( - session.query(AppTrigger) - .filter( + app_trigger = session.scalar( + select(AppTrigger) + .where( AppTrigger.app_id == webhook_trigger.app_id, AppTrigger.node_id == webhook_trigger.node_id, AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK, ) - .first() + .limit(1) ) if not app_trigger: @@ -146,14 +146,14 @@ class WebhookService: raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}") # Get workflow - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version != Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) if not workflow: raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}") diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 1b5252fc64..39693e3f4b 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App: def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: # Arrange fake_session = MagicMock() - fake_session.query.return_value = _FakeQuery(None) + fake_session.scalar.return_value = None _patch_session(monkeypatch, fake_session) # Act / Assert @@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun # Arrange webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)] + fake_session.scalar.side_effect = [webhook_trigger, None] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow] _patch_session(monkeypatch, fake_session) # Act @@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)] + fake_session.scalar.side_effect = [webhook_trigger, workflow] _patch_session(monkeypatch, fake_session) # Act From 75b88a54163c309b02a4aef1cf9abf640b24f62e Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:17:08 -0500 Subject: [PATCH 03/14] refactor: migrate session.query to select API in deal dataset index update task (#34847) --- api/tasks/deal_dataset_index_update_task.py | 69 ++++++++++++--------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index fa844a8647..c9b5121a08 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task # type: ignore +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: From 0a6494abfba2fa40eb05f4d45c30217c961a2597 Mon Sep 17 00:00:00 2001 From: Jonathan Chang <55106972+jonathanchang31@users.noreply.github.com> Date: Thu, 9 Apr 2026 09:24:39 -0500 Subject: [PATCH 04/14] refactor(api): deduplicate EnabledConfig property logic in AppModelConfig (#34793) --- api/models/model.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index ece3ff8b87..6e6e390902 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -674,28 +674,24 @@ class AppModelConfig(TypeBase): def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] + def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig: + return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) + @property def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, - json.loads(self.suggested_questions_after_answer) - if self.suggested_questions_after_answer - else {"enabled": False}, - ) + return self._get_enabled_config(self.suggested_questions_after_answer) @property def speech_to_text_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) + return self._get_enabled_config(self.speech_to_text) @property def text_to_speech_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) + return self._get_enabled_config(self.text_to_speech) @property def retriever_resource_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} - ) + return self._get_enabled_config(self.retriever_resource, default_enabled=True) @property def annotation_reply_dict(self) -> AnnotationReplyConfig: @@ -722,7 +718,7 @@ class AppModelConfig(TypeBase): @property def more_like_this_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) + return self._get_enabled_config(self.more_like_this) @property def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: From b8858708be1fa1a8b7fae29f0eb1cb46bebeaa02 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Thu, 9 Apr 2026 08:37:39 -0700 Subject: [PATCH 05/14] chore: remove commented-out reqparse code from rag_pipeline_workflow (#34860) Co-authored-by: ai-hpc --- .../rag_pipeline/rag_pipeline_workflow.py | 83 ------------------- 1 file changed, 83 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 6c02646c22..a8077d9eb0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=True -# ) -# -# return result - - -# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=False -# ) -# -# return result -# @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) From ab3b3056829965eec2334e77a9fc2724e58dec10 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Thu, 9 Apr 2026 08:38:16 -0700 Subject: [PATCH 06/14] refactor: migrate web human_input_form from reqparse to Pydantic BaseModel (#34859) Co-authored-by: ai-hpc --- api/controllers/web/human_input_form.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 36728a47d1..aff0b42d95 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -7,7 +7,8 @@ import logging from datetime import datetime from flask import Response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden @@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ logger = logging.getLogger(__name__) + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + _FORM_SUBMIT_RATE_LIMITER = RateLimiter( prefix="web_form_submit_rate_limit", max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, @@ -112,10 +119,7 @@ class HumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) ip_address = extract_remote_ip(request) if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): @@ -135,8 +139,8 @@ class HumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_end_user_id=None, # submission_end_user_id=_end_user.id, ) From 4d57f04a264403fd003996b86dcd7c289dd112e4 Mon Sep 17 00:00:00 2001 From: NVIDIAN Date: Thu, 9 Apr 2026 08:38:47 -0700 Subject: [PATCH 07/14] refactor: migrate console human_input_form from reqparse to PydanticBaseModel (#34858) Co-authored-by: ai-hpc --- api/controllers/console/human_input_form.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 5d79e1b5e9..845af37365 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -7,7 +7,8 @@ import logging from collections.abc import Generator from flask import Response, jsonify, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() payload["expiration_time"] = int(form.expiration_time.timestamp()) @@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) current_user, _ = current_account_with_tenant() service = HumanInputService(db.engine) @@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_user_id=current_user.id, ) From 985e71ebf42ed431e660ff61e060ac34e10d45dc Mon Sep 17 00:00:00 2001 From: sxxtony <166789813+sxxtony@users.noreply.github.com> Date: Thu, 9 Apr 2026 08:41:29 -0700 Subject: [PATCH 08/14] refactor: migrate TrialApp and AccountTrialAppRecord to TypeBase (#34806) --- api/models/model.py | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index 6e6e390902..d2ff8065e2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -898,7 +898,7 @@ class InstalledApp(TypeBase): return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) -class TrialApp(Base): +class TrialApp(TypeBase): __tablename__ = "trial_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), @@ -907,18 +907,26 @@ class TrialApp(Base): sa.UniqueConstraint("app_id", name="unique_trail_app_id"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - app_id = mapped_column(StringUUID, nullable=False) - tenant_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + insert_default=func.current_timestamp(), + server_default=func.current_timestamp(), + init=False, + ) + trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) @property def app(self) -> App | None: return db.session.scalar(select(App).where(App.id == self.app_id)) -class AccountTrialAppRecord(Base): +class AccountTrialAppRecord(TypeBase): __tablename__ = "account_trial_app_records" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), @@ -926,11 +934,19 @@ class AccountTrialAppRecord(Base): sa.Index("account_trial_app_record_app_id_idx", "app_id"), sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - account_id = mapped_column(StringUUID, nullable=False) - app_id = mapped_column(StringUUID, nullable=False) - count = mapped_column(sa.Integer, nullable=False, default=0) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + insert_default=func.current_timestamp(), + server_default=func.current_timestamp(), + init=False, + ) @property def app(self) -> App | None: From 2352269ba94a66782e00a9738395e2df66fa1b48 Mon Sep 17 00:00:00 2001 From: YBoy <231405196+YB0y@users.noreply.github.com> Date: Fri, 10 Apr 2026 02:32:24 +0200 Subject: [PATCH 09/14] refactor(api): type recommend app database retrieval dicts with TypedDicts (#34873) --- .../database/database_retrieval.py | 59 ++++++++++++++----- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 6fb90d356d..1df5fd13b6 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + from sqlalchemy import select from constants.languages import languages @@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType +class RecommendedAppItemDict(TypedDict): + id: str + app: App | None + app_id: str + description: Any + copyright: Any + privacy_policy: Any + custom_disclaimer: str + category: str + position: int + is_listed: bool + + +class RecommendedAppsResultDict(TypedDict): + recommended_apps: list[RecommendedAppItemDict] + categories: list[str] + + +class RecommendedAppDetailDict(TypedDict): + id: str + name: str + icon: Any + icon_background: str | None + mode: str + export_data: str + + class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ Retrieval recommended app from database """ - def get_recommended_apps_and_categories(self, language: str): + def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict: result = self.fetch_recommended_apps_from_db(language) return result - def get_recommend_app_detail(self, app_id: str): + def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) return result @@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str): + def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict: """ Fetch recommended apps from db. :param language: language @@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): ).all() categories = set() - recommended_apps_result = [] + recommended_apps_result: list[RecommendedAppItemDict] = [] for recommended_app in recommended_apps: app = recommended_app.app if not app or not app.is_public: @@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not site: continue - recommended_app_result = { + recommended_app_result: RecommendedAppItemDict = { "id": recommended_app.id, "app": recommended_app.app, "app_id": recommended_app.app_id, @@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): categories.add(recommended_app.category) - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories)) @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None: """ Fetch recommended app detail from db. :param app_id: App ID @@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not app_model or not app_model.is_public: return None - return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), - } + return RecommendedAppDetailDict( + id=app_model.id, + name=app_model.name, + icon=app_model.icon, + icon_background=app_model.icon_background, + mode=app_model.mode, + export_data=AppDslService.export_dsl(app_model=app_model), + ) From a31c1d2c69185324dec4931f2da0b7b75c209df9 Mon Sep 17 00:00:00 2001 From: dataCenter430 <161712630+dataCenter430@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:33:23 -0700 Subject: [PATCH 10/14] refactor(api): type Celery SSL options and Sentinel transport dicts with TypedDicts (#34871) --- api/extensions/ext_celery.py | 50 ++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 1b3ccd1207..86b0550187 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -5,12 +5,30 @@ from typing import Any import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp -def get_celery_ssl_options() -> dict[str, Any] | None: +class _CelerySentinelKwargsDict(TypedDict): + socket_timeout: float | None + password: str | None + + +class CelerySentinelTransportDict(TypedDict): + master_name: str | None + sentinel_kwargs: _CelerySentinelKwargsDict + + +class CelerySSLOptionsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +def get_celery_ssl_options() -> CelerySSLOptionsDict | None: """Get SSL configuration for Celery broker/backend connections.""" # Only apply SSL if we're using Redis as broker/backend if not dify_config.BROKER_USE_SSL: @@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None: ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) - ssl_options = { - "ssl_cert_reqs": ssl_cert_reqs, - "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, - "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, - "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, - } - - return ssl_options + return CelerySSLOptionsDict( + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS, + ssl_certfile=dify_config.REDIS_SSL_CERTFILE, + ssl_keyfile=dify_config.REDIS_SSL_KEYFILE, + ) -def get_celery_broker_transport_options() -> dict[str, Any]: +def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. Redis Sentinel) for Celery connections.""" if dify_config.CELERY_USE_SENTINEL: - return { - "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, - "sentinel_kwargs": { - "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, - "password": dify_config.CELERY_SENTINEL_PASSWORD, - }, - } + return CelerySentinelTransportDict( + master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, + sentinel_kwargs=_CelerySentinelKwargsDict( + socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + password=dify_config.CELERY_SENTINEL_PASSWORD, + ), + ) return {} From c5c5c71d1548af652a0d9ce0cee225a65eabbc4e Mon Sep 17 00:00:00 2001 From: dataCenter430 <161712630+dataCenter430@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:34:34 -0700 Subject: [PATCH 11/14] refactor(api): type OpenSearch/Lindorm/Huawei VDB config params dicts with TypedDicts (#34870) --- .../vdb/huawei/huawei_cloud_vector.py | 29 +++++++++++------ .../datasource/vdb/lindorm/lindorm_vector.py | 23 +++++++++----- .../vdb/opensearch/opensearch_vector.py | 31 ++++++++++++++----- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index df02c584ed..90d6d98c63 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -5,6 +5,7 @@ from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -19,6 +20,16 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class HuaweiElasticsearchParamsDict(TypedDict, total=False): + hosts: list[str] + verify_certs: bool + ssl_show_warn: bool + request_timeout: int + retry_on_timeout: bool + max_retries: int + basic_auth: tuple[str, str] + + def create_ssl_context() -> ssl.SSLContext: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False @@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel): raise ValueError("config HOSTS is required") return values - def to_elasticsearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts.split(","), - "verify_certs": False, - "ssl_show_warn": False, - "request_timeout": 30000, - "retry_on_timeout": True, - "max_retries": 10, - } + def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict: + params = HuaweiElasticsearchParamsDict( + hosts=self.hosts.split(","), + verify_certs=False, + ssl_show_warn=False, + request_timeout=30000, + retry_on_timeout=True, + max_retries=10, + ) if self.username and self.password: params["basic_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index bfcb620618..fbe0bcad02 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from tenacity import retry, stop_after_attempt, wait_exponential +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field" UGC_INDEX_PREFIX = "ugc_index" +class LindormOpenSearchParamsDict(TypedDict, total=False): + hosts: str | None + use_ssl: bool + pool_maxsize: int + timeout: int + http_auth: tuple[str, str] + + class LindormVectorStoreConfig(BaseModel): hosts: str | None username: str | None = None @@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel): raise ValueError("config PASSWORD is required") return values - def to_opensearch_params(self) -> dict[str, Any]: - params: dict[str, Any] = { - "hosts": self.hosts, - "use_ssl": False, - "pool_maxsize": 128, - "timeout": 30, - } + def to_opensearch_params(self) -> LindormOpenSearchParamsDict: + params = LindormOpenSearchParamsDict( + hosts=self.hosts, + use_ssl=False, + pool_maxsize=128, + timeout=30, + ) if self.username and self.password: params["http_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 2f77776807..50d18cdc4c 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -6,6 +6,7 @@ from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from configs.middleware.vdb.opensearch_config import AuthMethod @@ -21,6 +22,20 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class _OpenSearchHostDict(TypedDict): + host: str + port: int + + +class OpenSearchParamsDict(TypedDict, total=False): + hosts: list[_OpenSearchHostDict] + use_ssl: bool + verify_certs: bool + connection_class: type + pool_maxsize: int + http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth + + class OpenSearchConfig(BaseModel): host: str port: int @@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel): service=self.aws_service, # type: ignore[arg-type] ) - def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": [{"host": self.host, "port": self.port}], - "use_ssl": self.secure, - "verify_certs": self.verify_certs, - "connection_class": Urllib3HttpConnection, - "pool_maxsize": 20, - } + def to_opensearch_params(self) -> OpenSearchParamsDict: + params = OpenSearchParamsDict( + hosts=[{"host": self.host, "port": self.port}], + use_ssl=self.secure, + verify_certs=self.verify_certs, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) if self.auth_method == "basic": logger.info("Using basic authentication for OpenSearch Vector DB") From 1117b6e72d7647fdb51c6e91a72613ed4b6c63be Mon Sep 17 00:00:00 2001 From: dataCenter430 <161712630+dataCenter430@users.noreply.github.com> Date: Thu, 9 Apr 2026 17:35:12 -0700 Subject: [PATCH 12/14] refactor: convert appmode misc if/elif to match/case (#30001) (#34869) --- api/core/memory/token_buffer_memory.py | 39 ++++----- api/models/utils/file_input_compat.py | 15 ++-- api/services/app_dsl_service.py | 110 +++++++++++++------------ 3 files changed, 87 insertions(+), 77 deletions(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 09c84538a9..5809d6f74a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -61,27 +61,28 @@ class TokenBufferMemory: :param is_user_message: whether this is a user message :return: PromptMessage """ - if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: - file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) - elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - app = self.conversation.app - if not app: - raise ValueError("App not found for conversation") + match self.conversation.mode: + case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT: + file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + app = self.conversation.app + if not app: + raise ValueError("App not found for conversation") - if not message.workflow_run_id: - raise ValueError("Workflow run ID not found") + if not message.workflow_run_id: + raise ValueError("Workflow run ID not found") - workflow_run = self.workflow_run_repo.get_workflow_run_by_id( - tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id - ) - if not workflow_run: - raise ValueError(f"Workflow run not found: {message.workflow_run_id}") - workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) - if not workflow: - raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - else: - raise AssertionError(f"Invalid app mode: {self.conversation.mode}") + workflow_run = self.workflow_run_repo.get_workflow_run_by_id( + tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + case _: + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.HIGH if file_extra_config and app_record: diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index f71583c1cd..8b767779ce 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -66,12 +66,15 @@ def build_file_from_stored_mapping( record_id = resolve_file_record_id(mapping) transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) - if transfer_method == FileTransferMethod.TOOL_FILE and record_id: - mapping["tool_file_id"] = record_id - elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: - mapping["upload_file_id"] = record_id - elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: - mapping["datasource_file_id"] = record_id + match transfer_method: + case FileTransferMethod.TOOL_FILE if record_id: + mapping["tool_file_id"] = record_id + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id: + mapping["upload_file_id"] = record_id + case FileTransferMethod.DATASOURCE_FILE if record_id: + mapping["datasource_file_id"] = record_id + case _: + pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: remote_url = mapping.get("remote_url") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index c6c8a15109..40e1e5f8ab 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -467,61 +467,67 @@ class AppDslService: ) # Initialize app based on mode - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data = data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for workflow/advanced chat app") + match app_mode: + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) + for obj in conversation_variables_list + ] - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) - ] - workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - # Initialize model config - model_config = data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise ValueError("Missing model_config for chat/agent-chat/completion app") - # Initialize or update model config - if not app.app_model_config: - app_model_config = AppModelConfig( - app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(cast(AppModelConfigDict, model_config)) - app_model_config.id = str(uuid4()) - app.app_model_config_id = app_model_config.id + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, tenant_id=app.tenant_id + ) + ) + ] + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig( + app_id=app.id, created_by=account.id, updated_by=account.id + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) + app_model_config.id = str(uuid4()) + app.app_model_config_id = app_model_config.id - self._session.add(app_model_config) - app_model_config_was_updated.send(app, app_model_config=app_model_config) - else: - raise ValueError("Invalid app mode") + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + case _: + raise ValueError("Invalid app mode") return app @classmethod From d50f096b14fea8c7b4b3c793cb1d775a3d5db7d5 Mon Sep 17 00:00:00 2001 From: Jean Ibarz Date: Fri, 10 Apr 2026 03:28:57 +0200 Subject: [PATCH 13/14] =?UTF-8?q?fix(mcp):=20catch=20JSONDecodeError=20in?= =?UTF-8?q?=20OAuth=20discovery=20functions=20=F0=9F=A4=96=F0=9F=A4=96?= =?UTF-8?q?=F0=9F=A4=96=20(#34868)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/mcp/auth/auth_flow.py | 6 ++-- .../core/mcp/auth/test_auth_flow.py | 35 +++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index d015769b54..1d8356acf6 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -146,7 +146,7 @@ def discover_protected_resource_metadata( return ProtectedResourceMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata( return OAuthMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: else: return False, "" return False, "" - except RequestError: + except (RequestError, json.JSONDecodeError, IndexError): # Not support resource discovery, fall back to well-known OAuth metadata return False, "" diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index fe533e62af..1f5fdd2657 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -862,6 +862,15 @@ class TestAuthOrchestration: result = discover_protected_resource_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is None + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_authorization_server_metadata(self, mock_get): # Success @@ -892,6 +901,14 @@ class TestAuthOrchestration: result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is None + def test_get_effective_scope(self): prm = ProtectedResourceMetadata( resource="https://api.example.com", @@ -997,6 +1014,24 @@ class TestAuthOrchestration: supported, url = check_support_resource_discovery("https://api") assert supported is False + # Case 6: JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_res = Mock() + bad_json_res.status_code = 200 + bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + assert url == "" + + # Case 7: Empty authorization_servers array (IndexError) + empty_res = Mock() + empty_res.status_code = 200 + empty_res.json.return_value = {"authorization_servers": []} + mock_get.return_value = empty_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + def test_discover_oauth_metadata(self): with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: From 40e23ce8dc4d721328bc76287cc2488cf659d8ff Mon Sep 17 00:00:00 2001 From: volcano303 <75143900+volcano303@users.noreply.github.com> Date: Fri, 10 Apr 2026 03:47:59 +0200 Subject: [PATCH 14/14] refactor(api): type DatasourceProviderApiEntity.to_dict with TypedDict (#34879) --- api/core/datasource/entities/api_entities.py | 26 +++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 14d1af2e8b..5f90ba067c 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Optional +from typing import Any, Literal, Optional, TypedDict from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.tools.entities.common_entities import I18nObject +from core.tools.entities.common_entities import I18nObject, I18nObjectDict class DatasourceApiEntity(BaseModel): @@ -20,6 +20,23 @@ class DatasourceApiEntity(BaseModel): ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] +class DatasourceProviderApiEntityDict(TypedDict): + id: str + author: str + name: str + plugin_id: str | None + plugin_unique_identifier: str | None + description: I18nObjectDict + icon: str | dict + label: I18nObjectDict + type: str + team_credentials: dict | None + is_team_authorization: bool + allow_delete: bool + datasources: list[Any] + labels: list[str] + + class DatasourceProviderApiEntity(BaseModel): id: str author: str @@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel): def convert_none_to_empty_list(cls, v): return v if v is not None else [] - def to_dict(self) -> dict: + def to_dict(self) -> DatasourceProviderApiEntityDict: # ------------- # overwrite datasource parameter types for temp fix datasources = jsonable_encoder(self.datasources) @@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel): parameter["type"] = "files" # ------------- - return { + result: DatasourceProviderApiEntityDict = { "id": self.id, "author": self.author, "name": self.name, @@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel): "datasources": datasources, "labels": self.labels, } + return result