refactor: pass session as parameter in knowledge_retrieval_inner_service and agent_app_feature_service (#37639)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Evan 2026-06-24 13:34:41 +08:00 committed by GitHub
parent 2cde7e4a94
commit 67a5eacf2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 39 deletions

View File

@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource):
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session
app_model=app_model,
account=current_user,
config=args.model_dump(exclude_none=True),
session=db.session,
)
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import plugin_inner_api_only
from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc
from extensions.ext_database import db
from libs.exception import BaseHTTPException
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError
@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource):
) from exc
try:
response = InnerKnowledgeRetrievalService().retrieve(payload)
response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session)
except InnerKnowledgeRetrievalServiceError as exc:
raise InnerKnowledgeRetrievalHttpError(
error_code=exc.error_code,

View File

@ -69,7 +69,12 @@ class AgentAppFeatureConfigService:
@classmethod
def update_features(
cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session
cls,
*,
app_model: App,
account: Account,
config: dict[str, Any],
session: scoped_session,
) -> AppModelConfig:
"""Persist the presentation features as a new app_model_config version.

View File

@ -13,11 +13,11 @@ of a separate validation error.
"""
from sqlalchemy import select
from sqlalchemy.orm import scoped_session
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes.llm.entities import ModelConfig
from models.dataset import Dataset
@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import (
class InnerKnowledgeRetrievalService:
"""Validate inner caller scope and delegate to workflow dataset retrieval."""
def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse:
def retrieve(
self,
request: InnerKnowledgeRetrieveRequest,
session: scoped_session,
) -> InnerKnowledgeRetrieveResponse:
"""Run tenant-scoped retrieval for a trusted internal caller.
This method only rejects caller app existence/tenant mismatches and
@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService:
InnerKnowledgeRetrieveDatasetTenantMismatchError:
At least one requested dataset is outside the caller tenant.
"""
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id)
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids)
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session)
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session)
rag = DatasetRetrieval()
results = rag.knowledge_retrieval(request=self._to_rag_request(request))
@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService:
usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)),
)
def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None:
app = db.session.scalar(select(App).where(App.id == app_id).limit(1))
def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None:
app = session.scalar(select(App).where(App.id == app_id).limit(1))
if app is None:
raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found")
if app.tenant_id != tenant_id:
@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService:
f"App '{app_id}' does not belong to tenant '{tenant_id}'"
)
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None:
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None:
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
found_ids = {dataset.id for dataset in datasets}
missing_ids = sorted(set(dataset_ids) - found_ids)

View File

@ -74,14 +74,14 @@ def _build_source() -> Source:
class TestInnerKnowledgeRetrievalService:
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls):
def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_rag_cls):
request = _build_request()
mock_session = MagicMock()
mock_app = MagicMock(id="app-1", tenant_id="tenant-1")
dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False)
dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True)
mock_db.session.scalar.return_value = mock_app
mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2]
mock_session.scalar.return_value = mock_app
mock_session.scalars.return_value.all.return_value = [dataset_1, dataset_2]
rag = MagicMock()
rag.knowledge_retrieval.return_value = [_build_source()]
@ -101,7 +101,7 @@ class TestInnerKnowledgeRetrievalService:
}
mock_rag_cls.return_value = rag
response = InnerKnowledgeRetrievalService().retrieve(request)
response = InnerKnowledgeRetrievalService().retrieve(request, mock_session)
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
assert rag_request.tenant_id == "tenant-1"
@ -127,8 +127,7 @@ class TestInnerKnowledgeRetrievalService:
assert response.usage.currency == "USD"
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls):
def test_retrieve_maps_single_request(self, mock_rag_cls):
request = _build_request(
dataset_ids=["dataset-1"],
retrieval={
@ -151,8 +150,9 @@ class TestInnerKnowledgeRetrievalService:
},
attachment_ids=[],
)
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
rag = MagicMock()
rag.knowledge_retrieval.return_value = []
@ -172,7 +172,7 @@ class TestInnerKnowledgeRetrievalService:
}
mock_rag_cls.return_value = rag
InnerKnowledgeRetrievalService().retrieve(request)
InnerKnowledgeRetrievalService().retrieve(request, mock_session)
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
assert rag_request.retrieval_mode == "single"
@ -184,35 +184,35 @@ class TestInnerKnowledgeRetrievalService:
assert rag_request.metadata_model_config is not None
assert rag_request.metadata_model_config.provider == "openai"
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_raises_when_app_missing(self, mock_db):
mock_db.session.scalar.return_value = None
def test_retrieve_raises_when_app_missing(self):
mock_session = MagicMock()
mock_session.scalar.return_value = None
with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError):
InnerKnowledgeRetrievalService().retrieve(_build_request())
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db):
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2")
def test_retrieve_raises_when_app_belongs_to_other_tenant(self):
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2")
with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError):
InnerKnowledgeRetrievalService().retrieve(_build_request())
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_raises_when_dataset_missing(self, mock_db):
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
def test_retrieve_raises_when_dataset_missing(self):
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError):
InnerKnowledgeRetrievalService().retrieve(_build_request())
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
@patch("services.knowledge_retrieval_inner_service.db")
def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db):
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_db.session.scalars.return_value.all.return_value = [
def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self):
mock_session = MagicMock()
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
mock_session.scalars.return_value.all.return_value = [
MagicMock(id="dataset-1", tenant_id="tenant-1"),
MagicMock(id="dataset-2", tenant_id="tenant-2"),
]
with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError):
InnerKnowledgeRetrievalService().retrieve(_build_request())
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)