From bce61ff7ce53372e15d8fdd275268d302a1d19e2 Mon Sep 17 00:00:00 2001 From: FFXN <31929997+FFXN@users.noreply.github.com> Date: Tue, 28 Apr 2026 16:37:13 +0800 Subject: [PATCH] fix: hit-testing response failed because of Pydantic check. (#35640) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../console/datasets/hit_testing_base.py | 49 +++++++++++++++++- .../console/datasets/test_hit_testing_base.py | 36 +++++++++++++ .../service_api/dataset/test_hit_testing.py | 51 +++++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 699fa599c8..71ab1513ed 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -38,6 +38,48 @@ class HitTestingPayload(BaseModel): class DatasetsHitTestingBase: + @staticmethod + def _normalize_hit_testing_query(query: Any) -> str: + """Return the user-visible query string from legacy and current response shapes.""" + if isinstance(query, str): + return query + + if isinstance(query, dict): + content = query.get("content") + if isinstance(content, str): + return content + + raise ValueError("Invalid hit testing query response") + + @staticmethod + def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]: + """Coerce nullable collection fields into lists before response validation.""" + if not isinstance(records, list): + return [] + + normalized_records: list[dict[str, Any]] = [] + for record in records: + if not isinstance(record, dict): + continue + + normalized_record = dict(record) + segment = normalized_record.get("segment") + if isinstance(segment, dict): + normalized_segment = dict(segment) + if normalized_segment.get("keywords") is None: + normalized_segment["keywords"] = [] + normalized_record["segment"] = normalized_segment + + if normalized_record.get("child_chunks") is None: + normalized_record["child_chunks"] = [] + + if normalized_record.get("files") is None: + normalized_record["files"] = [] + + normalized_records.append(normalized_record) + + return normalized_records + @staticmethod def get_and_validate_dataset(dataset_id: str): assert isinstance(current_user, Account) @@ -75,7 +117,12 @@ class DatasetsHitTestingBase: attachment_ids=args.get("attachment_ids"), limit=10, ) - return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} + return { + "query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")), + "records": DatasetsHitTestingBase._normalize_hit_testing_records( + marshal(response.get("records", []), hit_testing_record_fields) + ), + } except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() except ProviderTokenNotInitError as ex: diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e4acd91b76..d29b34beb2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -134,6 +134,42 @@ class TestPerformHitTesting: assert result["query"] == "hello" assert result["records"] == [] + def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset): + response = { + "query": {"content": "hello"}, + "records": [ + { + "segment": {"id": "segment-1", "keywords": None}, + "child_chunks": None, + "files": None, + "score": 0.8, + } + ], + } + + with ( + patch.object( + HitTestingService, + "retrieve", + return_value=response, + ), + patch( + "controllers.console.datasets.hit_testing_base.marshal", + return_value=response["records"], + ), + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == "hello" + assert result["records"] == [ + { + "segment": {"id": "segment-1", "keywords": []}, + "child_chunks": [], + "files": [], + "score": 0.8, + } + ] + def test_index_not_initialized(self, dataset): with patch.object( HitTestingService, diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 95c2f5cf92..9be8e56f56 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -171,6 +171,57 @@ class TestHitTestingApiPost: assert passed_retrieval_model["search_method"] == "semantic_search" assert passed_retrieval_model["top_k"] == 10 + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") + @patch("controllers.console.datasets.hit_testing_base.marshal") + @patch("controllers.console.datasets.hit_testing_base.HitTestingService") + @patch("controllers.console.datasets.hit_testing_base.DatasetService") + @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) + def test_post_normalizes_legacy_query_and_nullable_list_fields( + self, + mock_current_user, + mock_dataset_svc, + mock_hit_svc, + mock_marshal, + mock_ns, + app, + ): + """Test service API normalizes legacy query shape and nullable list fields.""" + dataset_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + mock_dataset = Mock() + mock_dataset.id = dataset_id + + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.return_value = None + + mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]} + mock_hit_svc.hit_testing_args_check.return_value = None + mock_marshal.return_value = [ + { + "segment": {"id": "segment-1", "keywords": None}, + "child_chunks": None, + "files": None, + "score": 0.9, + } + ] + + mock_ns.payload = {"query": "legacy query"} + + with app.test_request_context(): + api = HitTestingApi() + response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) + + assert response["query"] == "legacy query" + assert response["records"] == [ + { + "segment": {"id": "segment-1", "keywords": []}, + "child_chunks": [], + "files": [], + "score": 0.9, + } + ] + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))