From 67a5eacf2db4d85f0d02ac9d487210b87961fe43 Mon Sep 17 00:00:00 2001 From: Evan <2869018789@qq.com> Date: Wed, 24 Jun 2026 13:34:41 +0800 Subject: [PATCH] refactor: pass session as parameter in knowledge_retrieval_inner_service and agent_app_feature_service (#37639) Co-authored-by: Asuka Minato --- .../console/app/agent_app_feature.py | 5 +- .../inner_api/knowledge/retrieval.py | 3 +- api/services/agent_app_feature_service.py | 7 ++- .../knowledge_retrieval_inner_service.py | 20 ++++--- .../test_knowledge_retrieval_inner_service.py | 56 +++++++++---------- 5 files changed, 52 insertions(+), 39 deletions(-) diff --git a/api/controllers/console/app/agent_app_feature.py b/api/controllers/console/app/agent_app_feature.py index d155dae6ac3..358e552beb0 100644 --- a/api/controllers/console/app/agent_app_feature.py +++ b/api/controllers/console/app/agent_app_feature.py @@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource): args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {}) new_app_model_config = AgentAppFeatureConfigService.update_features( - app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session + app_model=app_model, + account=current_user, + config=args.model_dump(exclude_none=True), + session=db.session, ) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) diff --git a/api/controllers/inner_api/knowledge/retrieval.py b/api/controllers/inner_api/knowledge/retrieval.py index 1c1320fde42..e34dedea286 100644 --- a/api/controllers/inner_api/knowledge/retrieval.py +++ b/api/controllers/inner_api/knowledge/retrieval.py @@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_ from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import plugin_inner_api_only from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc +from extensions.ext_database import db from libs.exception import BaseHTTPException from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError @@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource): ) from exc try: - response = InnerKnowledgeRetrievalService().retrieve(payload) + response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session) except InnerKnowledgeRetrievalServiceError as exc: raise InnerKnowledgeRetrievalHttpError( error_code=exc.error_code, diff --git a/api/services/agent_app_feature_service.py b/api/services/agent_app_feature_service.py index b8e98653c8e..5fd794bb10f 100644 --- a/api/services/agent_app_feature_service.py +++ b/api/services/agent_app_feature_service.py @@ -69,7 +69,12 @@ class AgentAppFeatureConfigService: @classmethod def update_features( - cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session + cls, + *, + app_model: App, + account: Account, + config: dict[str, Any], + session: scoped_session, ) -> AppModelConfig: """Persist the presentation features as a new app_model_config version. diff --git a/api/services/knowledge_retrieval_inner_service.py b/api/services/knowledge_retrieval_inner_service.py index fccc81c4a29..8759413f533 100644 --- a/api/services/knowledge_retrieval_inner_service.py +++ b/api/services/knowledge_retrieval_inner_service.py @@ -13,11 +13,11 @@ of a separate validation error. """ from sqlalchemy import select +from sqlalchemy.orm import scoped_session from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from extensions.ext_database import db from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes.llm.entities import ModelConfig from models.dataset import Dataset @@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import ( class InnerKnowledgeRetrievalService: """Validate inner caller scope and delegate to workflow dataset retrieval.""" - def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse: + def retrieve( + self, + request: InnerKnowledgeRetrieveRequest, + session: scoped_session, + ) -> InnerKnowledgeRetrieveResponse: """Run tenant-scoped retrieval for a trusted internal caller. This method only rejects caller app existence/tenant mismatches and @@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService: InnerKnowledgeRetrieveDatasetTenantMismatchError: At least one requested dataset is outside the caller tenant. """ - self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id) - self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids) + self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session) + self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session) rag = DatasetRetrieval() results = rag.knowledge_retrieval(request=self._to_rag_request(request)) @@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService: usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)), ) - def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None: - app = db.session.scalar(select(App).where(App.id == app_id).limit(1)) + def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None: + app = session.scalar(select(App).where(App.id == app_id).limit(1)) if app is None: raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found") if app.tenant_id != tenant_id: @@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService: f"App '{app_id}' does not belong to tenant '{tenant_id}'" ) - def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None: - datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None: + datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() found_ids = {dataset.id for dataset in datasets} missing_ids = sorted(set(dataset_ids) - found_ids) diff --git a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py index 287d787ad70..7a8efe85f13 100644 --- a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py +++ b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py @@ -74,14 +74,14 @@ def _build_source() -> Source: class TestInnerKnowledgeRetrievalService: @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls): + def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_rag_cls): request = _build_request() + mock_session = MagicMock() mock_app = MagicMock(id="app-1", tenant_id="tenant-1") dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False) dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True) - mock_db.session.scalar.return_value = mock_app - mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2] + mock_session.scalar.return_value = mock_app + mock_session.scalars.return_value.all.return_value = [dataset_1, dataset_2] rag = MagicMock() rag.knowledge_retrieval.return_value = [_build_source()] @@ -101,7 +101,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - response = InnerKnowledgeRetrievalService().retrieve(request) + response = InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.tenant_id == "tenant-1" @@ -127,8 +127,7 @@ class TestInnerKnowledgeRetrievalService: assert response.usage.currency == "USD" @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls): + def test_retrieve_maps_single_request(self, mock_rag_cls): request = _build_request( dataset_ids=["dataset-1"], retrieval={ @@ -151,8 +150,9 @@ class TestInnerKnowledgeRetrievalService: }, attachment_ids=[], ) - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] rag = MagicMock() rag.knowledge_retrieval.return_value = [] @@ -172,7 +172,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - InnerKnowledgeRetrievalService().retrieve(request) + InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.retrieval_mode == "single" @@ -184,35 +184,35 @@ class TestInnerKnowledgeRetrievalService: assert rag_request.metadata_model_config is not None assert rag_request.metadata_model_config.provider == "openai" - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_missing(self, mock_db): - mock_db.session.scalar.return_value = None + def test_retrieve_raises_when_app_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = None with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") + def test_retrieve_raises_when_app_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_missing(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + def test_retrieve_raises_when_dataset_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [ + def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [ MagicMock(id="dataset-1", tenant_id="tenant-1"), MagicMock(id="dataset-2", tenant_id="tenant-2"), ] with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)