mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 21:11:16 +08:00
refactor: pass session as parameter in knowledge_retrieval_inner_service and agent_app_feature_service (#37639)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
2cde7e4a94
commit
67a5eacf2d
@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc
|
||||
from extensions.ext_database import db
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse
|
||||
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError
|
||||
@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource):
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response = InnerKnowledgeRetrievalService().retrieve(payload)
|
||||
response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session)
|
||||
except InnerKnowledgeRetrievalServiceError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code=exc.error_code,
|
||||
|
||||
@ -69,7 +69,12 @@ class AgentAppFeatureConfigService:
|
||||
|
||||
@classmethod
|
||||
def update_features(
|
||||
cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session
|
||||
cls,
|
||||
*,
|
||||
app_model: App,
|
||||
account: Account,
|
||||
config: dict[str, Any],
|
||||
session: scoped_session,
|
||||
) -> AppModelConfig:
|
||||
"""Persist the presentation features as a new app_model_config version.
|
||||
|
||||
|
||||
@ -13,11 +13,11 @@ of a separate validation error.
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes.llm.entities import ModelConfig
|
||||
from models.dataset import Dataset
|
||||
@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import (
|
||||
class InnerKnowledgeRetrievalService:
|
||||
"""Validate inner caller scope and delegate to workflow dataset retrieval."""
|
||||
|
||||
def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse:
|
||||
def retrieve(
|
||||
self,
|
||||
request: InnerKnowledgeRetrieveRequest,
|
||||
session: scoped_session,
|
||||
) -> InnerKnowledgeRetrieveResponse:
|
||||
"""Run tenant-scoped retrieval for a trusted internal caller.
|
||||
|
||||
This method only rejects caller app existence/tenant mismatches and
|
||||
@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService:
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError:
|
||||
At least one requested dataset is outside the caller tenant.
|
||||
"""
|
||||
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id)
|
||||
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids)
|
||||
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session)
|
||||
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session)
|
||||
|
||||
rag = DatasetRetrieval()
|
||||
results = rag.knowledge_retrieval(request=self._to_rag_request(request))
|
||||
@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService:
|
||||
usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)),
|
||||
)
|
||||
|
||||
def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id).limit(1))
|
||||
def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None:
|
||||
app = session.scalar(select(App).where(App.id == app_id).limit(1))
|
||||
if app is None:
|
||||
raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found")
|
||||
if app.tenant_id != tenant_id:
|
||||
@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService:
|
||||
f"App '{app_id}' does not belong to tenant '{tenant_id}'"
|
||||
)
|
||||
|
||||
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None:
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None:
|
||||
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
|
||||
found_ids = {dataset.id for dataset in datasets}
|
||||
missing_ids = sorted(set(dataset_ids) - found_ids)
|
||||
|
||||
@ -74,14 +74,14 @@ def _build_source() -> Source:
|
||||
|
||||
class TestInnerKnowledgeRetrievalService:
|
||||
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls):
|
||||
def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_rag_cls):
|
||||
request = _build_request()
|
||||
mock_session = MagicMock()
|
||||
mock_app = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False)
|
||||
dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True)
|
||||
mock_db.session.scalar.return_value = mock_app
|
||||
mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2]
|
||||
mock_session.scalar.return_value = mock_app
|
||||
mock_session.scalars.return_value.all.return_value = [dataset_1, dataset_2]
|
||||
|
||||
rag = MagicMock()
|
||||
rag.knowledge_retrieval.return_value = [_build_source()]
|
||||
@ -101,7 +101,7 @@ class TestInnerKnowledgeRetrievalService:
|
||||
}
|
||||
mock_rag_cls.return_value = rag
|
||||
|
||||
response = InnerKnowledgeRetrievalService().retrieve(request)
|
||||
response = InnerKnowledgeRetrievalService().retrieve(request, mock_session)
|
||||
|
||||
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
|
||||
assert rag_request.tenant_id == "tenant-1"
|
||||
@ -127,8 +127,7 @@ class TestInnerKnowledgeRetrievalService:
|
||||
assert response.usage.currency == "USD"
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls):
|
||||
def test_retrieve_maps_single_request(self, mock_rag_cls):
|
||||
request = _build_request(
|
||||
dataset_ids=["dataset-1"],
|
||||
retrieval={
|
||||
@ -151,8 +150,9 @@ class TestInnerKnowledgeRetrievalService:
|
||||
},
|
||||
attachment_ids=[],
|
||||
)
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
|
||||
rag = MagicMock()
|
||||
rag.knowledge_retrieval.return_value = []
|
||||
@ -172,7 +172,7 @@ class TestInnerKnowledgeRetrievalService:
|
||||
}
|
||||
mock_rag_cls.return_value = rag
|
||||
|
||||
InnerKnowledgeRetrievalService().retrieve(request)
|
||||
InnerKnowledgeRetrievalService().retrieve(request, mock_session)
|
||||
|
||||
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
|
||||
assert rag_request.retrieval_mode == "single"
|
||||
@ -184,35 +184,35 @@ class TestInnerKnowledgeRetrievalService:
|
||||
assert rag_request.metadata_model_config is not None
|
||||
assert rag_request.metadata_model_config.provider == "openai"
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_app_missing(self, mock_db):
|
||||
mock_db.session.scalar.return_value = None
|
||||
def test_retrieve_raises_when_app_missing(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2")
|
||||
def test_retrieve_raises_when_app_belongs_to_other_tenant(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2")
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_dataset_missing(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
def test_retrieve_raises_when_dataset_missing(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [
|
||||
def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self):
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_session.scalars.return_value.all.return_value = [
|
||||
MagicMock(id="dataset-1", tenant_id="tenant-1"),
|
||||
MagicMock(id="dataset-2", tenant_id="tenant-2"),
|
||||
]
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user