refactor(api): migrate remaining console APIs to use injected user/tenant (#37288)

This commit is contained in:
chariri 2026-06-11 10:30:31 +09:00 committed by GitHub
parent 5ed663e7fd
commit 2a46a7d91d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 1448 additions and 1265 deletions

View File

@ -8,6 +8,7 @@ from controllers.common.schema import register_response_schema_models, register_
from fields.hit_testing_fields import HitTestingResponse
from libs.helper import dump_response
from libs.login import login_required
from models import Account
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
@ -15,6 +16,8 @@ from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
register_schema_models(console_ns, HitTestingPayload)
@ -38,11 +41,16 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID) -> dict[str, object]:
@with_current_tenant_id
@with_current_user
def post(self, current_user: Account, current_tenant_id: str, dataset_id: UUID) -> dict[str, object]:
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
dataset = self.get_and_validate_dataset(dataset_id_str, current_user, current_tenant_id)
args = self.parse_args(console_ns.payload)
self.hit_testing_args_check(args)
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
return dump_response(
HitTestingResponse,
self.perform_hit_testing(dataset, args, current_user, current_tenant_id),
)

View File

@ -19,7 +19,7 @@ from core.errors.error import (
QuotaExceededError,
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_user
from libs.login import resolve_account_fallback
from models.account import Account
from models.dataset import Dataset
from services.dataset_service import DatasetService
@ -71,8 +71,10 @@ class DatasetsHitTestingBase:
return normalized_records
@staticmethod
def get_and_validate_dataset(dataset_id: str) -> Dataset:
assert isinstance(current_user, Account)
def get_and_validate_dataset(
dataset_id: str, current_user: Account | None = None, current_tenant_id: str | None = None
) -> Dataset:
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -95,9 +97,14 @@ class DatasetsHitTestingBase:
return hit_testing_payload.model_dump(exclude_none=True)
@staticmethod
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
assert isinstance(current_user, Account)
def perform_hit_testing(
dataset: Dataset,
args: dict[str, Any],
current_user: Account | None = None,
current_tenant_id: str | None = None,
) -> dict[str, Any]:
try:
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
response = HitTestingService.retrieve(
dataset=dataset,
query=cast(str, args.get("query")),

View File

@ -11,6 +11,7 @@ from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.dataset_fields import (
@ -50,7 +51,8 @@ class DatasetMetadataCreateApi(Resource):
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
@ -59,7 +61,7 @@ class DatasetMetadataCreateApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id)
return dump_response(DatasetMetadataResponse, metadata), 201
@setup_required
@ -87,7 +89,8 @@ class DatasetMetadataApi(Resource):
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
@with_current_user
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, metadata_id: UUID):
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -98,7 +101,9 @@ class DatasetMetadataApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
metadata = MetadataService.update_metadata_name(
dataset_id_str, metadata_id_str, name, current_user, current_tenant_id
)
return dump_response(DatasetMetadataResponse, metadata), 200
@setup_required
@ -181,7 +186,7 @@ class DocumentMetadataEditApi(Resource):
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)
MetadataService.update_documents_metadata(dataset, metadata_args, current_user)
# Frontend callers only await success and invalidate caches; no response body is consumed.
return "", 204

View File

@ -21,11 +21,14 @@ from controllers.console.wraps import (
enterprise_license_required,
knowledge_pipeline_publish_enabled,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import login_required
from models.account import Account
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -96,10 +99,11 @@ class PipelineTemplateListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def get(self) -> JsonResponseWithStatus:
@with_current_tenant_id
def get(self, current_tenant_id: str) -> JsonResponseWithStatus:
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language, current_tenant_id)
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
@ -128,10 +132,14 @@ class CustomizedPipelineTemplateApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str) -> tuple[str, int]:
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, template_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
RagPipelineService.update_customized_pipeline_template(
template_id, pipeline_template_info, current_user, current_tenant_id
)
return "", 204
@console_ns.response(204, "Pipeline template deleted")
@ -139,8 +147,9 @@ class CustomizedPipelineTemplateApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str) -> tuple[str, int]:
RagPipelineService.delete_customized_pipeline_template(template_id)
@with_current_tenant_id
def delete(self, current_tenant_id: str, template_id: str) -> tuple[str, int]:
RagPipelineService.delete_customized_pipeline_template(template_id, current_tenant_id)
return "", 204
@setup_required
@ -168,8 +177,12 @@ class PublishCustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str) -> tuple[str, int]:
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, pipeline_id: str) -> tuple[str, int]:
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
rag_pipeline_service.publish_customized_pipeline_template(
pipeline_id, payload.model_dump(), current_user, current_tenant_id
)
return "", 204

View File

@ -48,7 +48,7 @@ from fields.workflow_run_fields import (
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs import helper
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
from libs.login import current_user, login_required
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
@ -835,7 +835,7 @@ class RagPipelineWorkflowRunListApi(Resource):
}
)
args = {
"last_id": str(query.last_id) if query.last_id else None,
"last_id": query.last_id or None,
"limit": query.limit,
}
@ -881,7 +881,8 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: UUID):
@with_current_user
def get(self, current_user: Account, pipeline: Pipeline, run_id: UUID):
"""
Get workflow run node execution list
"""
@ -988,9 +989,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type, current_user, current_tenant_id)
return recommended_plugins

View File

@ -18,7 +18,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user_id
from controllers.console.wraps import with_current_user, with_current_user_id
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -30,7 +30,6 @@ from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
@ -84,7 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse)
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app: InstalledApp):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -101,8 +101,6 @@ class CompletionApi(InstalledAppResource):
db.session.commit()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
@ -160,7 +158,8 @@ class CompletionStopApi(InstalledAppResource):
)
class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app: InstalledApp):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -177,8 +176,6 @@ class ChatApi(InstalledAppResource):
db.session.commit()
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)

View File

@ -11,6 +11,7 @@ from controllers.common.schema import register_response_schema_models, register_
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -19,7 +20,6 @@ from fields.conversation_fields import (
SimpleConversation,
)
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode, InstalledApp
from services.conversation_service import ConversationService
@ -45,7 +45,8 @@ register_response_schema_models(console_ns, ResultResponse)
)
class ConversationListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app: InstalledApp):
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -66,14 +67,12 @@ class ConversationListApi(InstalledAppResource):
args = ConversationListQuery.model_validate(raw_args)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with sessionmaker(db.engine).begin() as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
last_id=str(args.last_id) if args.last_id else None,
last_id=args.last_id or None,
limit=args.limit,
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
@ -95,7 +94,8 @@ class ConversationListApi(InstalledAppResource):
)
class ConversationApi(InstalledAppResource):
@console_ns.response(204, "Conversation deleted successfully")
def delete(self, installed_app: InstalledApp, c_id: UUID):
@with_current_user
def delete(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -105,8 +105,6 @@ class ConversationApi(InstalledAppResource):
conversation_id = str(c_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -120,7 +118,8 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app: InstalledApp, c_id: UUID):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -133,8 +132,6 @@ class ConversationRenameApi(InstalledAppResource):
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
conversation = ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
@ -153,7 +150,8 @@ class ConversationRenameApi(InstalledAppResource):
)
class ConversationPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app: InstalledApp, c_id: UUID):
@with_current_user
def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -164,8 +162,6 @@ class ConversationPinApi(InstalledAppResource):
conversation_id = str(c_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.pin(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -179,7 +175,8 @@ class ConversationPinApi(InstalledAppResource):
)
class ConversationUnPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app: InstalledApp, c_id: UUID):
@with_current_user
def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
@ -188,8 +185,6 @@ class ConversationUnPinApi(InstalledAppResource):
raise NotChatAppError()
conversation_id = str(c_id)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
return ResultResponse(result="success").model_dump(mode="json")

View File

@ -33,6 +33,7 @@ from controllers.console.explore.error import (
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -63,7 +64,6 @@ from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
@ -155,7 +155,8 @@ register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeech
class TrialAppWorkflowRunApi(TrialAppResource):
@trial_feature_enable
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
@with_current_user
def post(self, current_user: Account, trial_app):
"""
Run workflow
"""
@ -168,7 +169,6 @@ class TrialAppWorkflowRunApi(TrialAppResource):
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
@ -206,7 +206,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
@ -221,7 +220,8 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
class TrialChatApi(TrialAppResource):
@console_ns.expect(console_ns.models[ChatRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
@with_current_user
def post(self, current_user: Account, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -239,9 +239,6 @@ class TrialChatApi(TrialAppResource):
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
@ -276,7 +273,8 @@ class TrialChatApi(TrialAppResource):
class TrialMessageSuggestedQuestionApi(TrialAppResource):
def get(self, trial_app, message_id):
@with_current_user
def get(self, current_user: Account, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -285,8 +283,6 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource):
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
@ -313,15 +309,13 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource):
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
@with_current_user
def post(self, current_user: Account, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
@ -358,7 +352,8 @@ class TrialChatAudioApi(TrialAppResource):
class TrialChatTextApi(TrialAppResource):
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
@with_current_user
def post(self, current_user: Account, trial_app):
app_model = trial_app
try:
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
@ -366,8 +361,6 @@ class TrialChatTextApi(TrialAppResource):
message_id = request_data.message_id
text = request_data.text
voice = request_data.voice
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
@ -405,7 +398,8 @@ class TrialChatTextApi(TrialAppResource):
class TrialCompletionApi(TrialAppResource):
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
@with_current_user
def post(self, current_user: Account, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -417,9 +411,6 @@ class TrialCompletionApi(TrialAppResource):
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id

View File

@ -1,10 +1,9 @@
from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from fields.base import ResponseModel
from libs.helper import dump_response
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant_optional, login_required
from services.feature_service import (
FeatureModel,
FeatureService,
@ -13,7 +12,12 @@ from services.feature_service import (
)
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
from .wraps import (
account_initialization_required,
cloud_utm_record,
setup_required,
with_current_tenant_id,
)
class TrialModelsResponse(ResponseModel):
@ -133,12 +137,6 @@ class SystemFeatureApi(Resource):
Only non-sensitive configuration data should be returned by this endpoint.
"""
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
# raise `Unauthorized` exception if authentication token is not provided.
try:
is_authenticated = current_user.is_authenticated
except Unauthorized:
is_authenticated = False
current_user, _ = current_account_with_tenant_optional()
is_authenticated = current_user is not None
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()

View File

@ -17,7 +17,7 @@ from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_user, login_required
from libs.login import login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
from services.plugin.oauth_service import OAuthProxyService
@ -31,6 +31,8 @@ from ..wraps import (
edit_permission_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
logger = logging.getLogger(__name__)
@ -77,12 +79,9 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
return TriggerManager.get_trigger_plugin_icon(tenant_id=tenant_id, provider_id=provider)
@console_ns.route("/workspaces/current/triggers")
@ -90,12 +89,10 @@ class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
@with_current_tenant_id
def get(self, tenant_id: str):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
return jsonable_encoder(TriggerProviderService.list_trigger_providers(tenant_id))
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
@ -103,14 +100,10 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
)
return jsonable_encoder(TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider)))
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
@ -119,16 +112,14 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_subscriptions(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=TriggerProviderID(provider),
user=user,
)
@ -149,17 +140,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
credential_type = CredentialType.of(payload.credential_type)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
@ -194,17 +184,16 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
return TriggerSubscriptionBuilderService.update_and_verify_builder(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
@ -226,17 +215,14 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str, subscription_builder_id: str):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
@ -262,10 +248,6 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@account_initialization_required
def get(self, provider: str, subscription_builder_id: str):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
@ -283,15 +265,15 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_builder_id: str):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
@ -316,15 +298,13 @@ class TriggerSubscriptionUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, subscription_id: str):
@with_current_tenant_id
def post(self, tenant_id: str, subscription_id: str):
"""Update a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
)
if not subscription:
@ -341,7 +321,7 @@ class TriggerSubscriptionUpdateApi(Resource):
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
if rename or manually_created:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
name=request.name,
properties=request.properties,
@ -351,7 +331,7 @@ class TriggerSubscriptionUpdateApi(Resource):
# For the rest cases(API_KEY, OAUTH2)
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
name=request.name,
provider_id=provider_id,
subscription_id=subscription_id,
@ -375,23 +355,21 @@ class TriggerSubscriptionDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, subscription_id: str):
@with_current_tenant_id
def post(self, tenant_id: str, subscription_id: str):
"""Delete a subscription instance"""
user = current_user
assert user.current_tenant_id is not None
try:
with sessionmaker(db.engine).begin() as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
)
# Delete plugin triggers
TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription(
session=session,
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
subscription_id=subscription_id,
)
return {"result": "success"}
@ -407,17 +385,14 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
@ -557,30 +532,28 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=provider_id,
)
system_client_exists = TriggerProviderService.is_oauth_system_client_exists(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=provider_id,
)
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
return jsonable_encoder(
{
@ -603,17 +576,15 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=provider_id,
client_params=payload.client_params,
enabled=payload.enabled,
@ -629,16 +600,14 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
@with_current_tenant_id
def delete(self, tenant_id: str, provider: str):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider_id=provider_id,
)
except ValueError as e:
@ -657,16 +626,15 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider: str, subscription_id: str):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str, subscription_id: str):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
result = TriggerProviderService.verify_subscription_credentials(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
subscription_id=subscription_id,

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Concatenate, cast, overload
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy
from configs import dify_config
@ -48,6 +49,53 @@ def current_account_with_tenant() -> tuple[Account, str]:
return user, user.current_tenant_id
def current_account_with_tenant_optional() -> tuple[Account | None, str | None]:
try:
user = _resolve_current_user()
except Unauthorized:
return None, None
if not isinstance(user, Account):
return None, None
if not bool(getattr(user, "is_authenticated", False)):
return None, None
return user, user.current_tenant_id
def resolve_account_fallback(
current_user: Account | None = None,
current_tenant_id: str | None = None,
*,
fallback_tenant_id: str | None = None,
) -> tuple[Account, str]:
"""
If the provided current user and tenant ID is None, fallback to current_account_with_tenant.
This is useful for those service layers whose controllers are not migrated to use DI for
resolving current user yet.
TODO: this should be removed after all ctrls (especially service API) are migrated
"""
if current_user is not None:
tenant_id = current_tenant_id or fallback_tenant_id
if tenant_id is None:
raise ValueError("current_tenant_id is required when current_user is provided.")
return current_user, tenant_id
return current_account_with_tenant()
def resolve_tenant_id_fallback(current_tenant_id: str | None = None) -> str:
"""
If the provided tenant ID is None, fallback to the tenant resolved from current_account_with_tenant.
This is useful for tenant-only service paths whose controllers are not all migrated to tenant injection yet.
TODO: this should be removed after all ctrls (especially service API) are migrated
"""
if current_tenant_id is not None:
return current_tenant_id
_, tenant_id = current_account_with_tenant()
return tenant_id
@overload
def login_required[T, **P, R](
func: Callable[Concatenate[T, P], R],

View File

@ -7,7 +7,8 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant
from libs.login import resolve_account_fallback
from models import Account
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from models.enums import DatasetMetadataType
from services.dataset_service import DocumentService
@ -21,11 +22,16 @@ logger = logging.getLogger(__name__)
class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
def create_metadata(
dataset_id: str,
metadata_args: MetadataArgs,
current_user: Account | None = None, # TODO: the service_api is not migrated yet
current_tenant_id: str | None = None,
) -> DatasetMetadata:
# check if metadata name is too long
if len(metadata_args.name) > 255:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
# check if metadata name already exists
if db.session.scalar(
select(DatasetMetadata)
@ -52,14 +58,20 @@ class MetadataService:
return metadata
@staticmethod
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
def update_metadata_name(
dataset_id: str,
metadata_id: str,
name: str,
current_user: Account | None = None,
current_tenant_id: str | None = None, # TODO: the service_api is not migrated yet
) -> DatasetMetadata | None:
# check if metadata name is too long
if len(name) > 255:
raise ValueError("Metadata name cannot exceed 255 characters.")
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
if db.session.scalar(
select(DatasetMetadata)
.where(
@ -107,6 +119,7 @@ class MetadataService:
return metadata
except Exception:
logger.exception("Update metadata name failed")
return None
finally:
redis_client.delete(lock_key)
@ -217,7 +230,15 @@ class MetadataService:
redis_client.delete(lock_key)
@staticmethod
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
def update_documents_metadata(
dataset: Dataset,
metadata_args: MetadataOperationData,
current_user: Account | None = None, # TODO: the service_api is not migrated yet
current_tenant_id: str | None = None,
):
current_user, current_tenant_id = resolve_account_fallback(
current_user, current_tenant_id, fallback_tenant_id=dataset.tenant_id
)
for operation in metadata_args.operation_data:
lock_key = f"document_metadata_lock_{operation.document_id}"
try:
@ -248,7 +269,6 @@ class MetadataService:
)
)
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:

View File

@ -21,7 +21,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.BUILTIN
@override
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
del current_tenant_id
result = self.fetch_pipeline_templates_from_builtin(language)
return result

View File

@ -4,7 +4,7 @@ import yaml
from sqlalchemy import select
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from libs.login import resolve_tenant_id_fallback
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
@ -40,8 +40,8 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
@override
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
current_tenant_id = resolve_tenant_id_fallback(current_tenant_id)
return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
@override

View File

@ -40,7 +40,8 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
@override
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
del current_tenant_id
return self.fetch_pipeline_templates_from_db(language)
@override

View File

@ -4,7 +4,7 @@ from typing import Any, Protocol
class PipelineTemplateRetrievalBase(Protocol):
"""Interface for pipeline template retrieval."""
def get_pipeline_templates(self, language: str) -> dict[str, Any]: ...
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: ...
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ...

View File

@ -25,7 +25,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
@override
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
del current_tenant_id
try:
return self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:

View File

@ -8,7 +8,6 @@ from datetime import UTC, datetime
from typing import Any, cast
from uuid import uuid4
from flask_login import current_user
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
@ -54,6 +53,7 @@ from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_htt
from graphon.runtime import VariablePool
from graphon.variables.variables import Variable, VariableBase
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import resolve_account_fallback, resolve_tenant_id_fallback
from models import Account
from models.dataset import ( # type: ignore
Dataset,
@ -104,11 +104,16 @@ class RagPipelineService:
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@classmethod
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
def get_pipeline_templates(
cls,
type: str = "built-in",
language: str = "en-US",
current_tenant_id: str | None = None,
) -> dict[str, Any]:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language)
result = retrieval_instance.get_pipeline_templates(language, current_tenant_id)
if not result.get("pipeline_templates") and language != "en-US":
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
@ -116,7 +121,7 @@ class RagPipelineService:
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result = retrieval_instance.get_pipeline_templates(language)
result = retrieval_instance.get_pipeline_templates(language, current_tenant_id)
return result
@classmethod
@ -146,17 +151,24 @@ class RagPipelineService:
return customized_result
@classmethod
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
def update_customized_pipeline_template(
cls,
template_id: str,
template_info: PipelineTemplateInfoEntity,
current_user: Account | None = None,
current_tenant_id: str | None = None,
):
"""
Update pipeline template.
:param template_id: template id
:param template_info: template info
"""
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
)
.limit(1)
)
@ -169,7 +181,7 @@ class RagPipelineService:
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.limit(1)
@ -184,15 +196,16 @@ class RagPipelineService:
return customized_template
@classmethod
def delete_customized_pipeline_template(cls, template_id: str):
def delete_customized_pipeline_template(cls, template_id: str, current_tenant_id: str | None = None):
"""
Delete customized pipeline template.
"""
current_tenant_id = resolve_tenant_id_fallback(current_tenant_id)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
)
.limit(1)
)
@ -1174,10 +1187,17 @@ class RagPipelineService:
return list(node_executions)
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
def publish_customized_pipeline_template(
cls,
pipeline_id: str,
args: dict[str, Any],
current_user: Account | None = None,
current_tenant_id: str | None = None,
):
"""
Publish customized pipeline template
"""
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
@ -1357,7 +1377,7 @@ class RagPipelineService:
return []
return marketplace.batch_fetch_plugin_by_ids(plugin_ids)
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
def get_recommended_plugins(self, type: str, current_user: Account, current_tenant_id: str) -> dict[str, Any]:
# Query active recommended plugins
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
@ -1375,7 +1395,7 @@ class RagPipelineService:
plugin_ids = [plugin.plugin_id for plugin in pipeline_recommended_plugins]
providers = BuiltinToolManageService.list_builtin_tools(
user_id=current_user.id,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
)
providers_map = {provider.plugin_id: provider.to_dict() for provider in providers}

View File

@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from inspect import unwrap
from typing import cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
@ -20,8 +21,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
from models.account import Account
from models.dataset import PipelineCustomizedTemplate
from tests.test_containers_integration_tests.controllers.console.helpers import unwrap
class TestPipelineTemplateListApi:
@ -53,7 +54,7 @@ class TestPipelineTemplateListApi:
return_value=templates,
),
):
response, status = method(api)
response, status = method(api, str(uuid4()))
assert status == 200
assert response == {
@ -147,6 +148,9 @@ class TestCustomizedPipelineTemplateApi:
def test_patch_success(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid4())
tenant_id = str(uuid4())
payload = {
"name": "Template",
@ -161,15 +165,18 @@ class TestCustomizedPipelineTemplateApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
) as update_mock,
):
response, status = method(api, "tpl-1")
response, status = method(api, tenant_id, account, "tpl-1")
update_mock.assert_called_once()
assert update_mock.call_args.args[2] is account
assert update_mock.call_args.args[3] == tenant_id
assert status == 204
assert response == ""
def test_delete_success(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
tenant_id = str(uuid4())
with (
app.test_request_context("/"),
@ -177,9 +184,9 @@ class TestCustomizedPipelineTemplateApi:
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
) as delete_mock,
):
response, status = method(api, "tpl-1")
response, status = method(api, tenant_id, "tpl-1")
delete_mock.assert_called_once_with("tpl-1")
delete_mock.assert_called_once_with("tpl-1", tenant_id)
assert status == 204
assert response == ""
@ -227,6 +234,9 @@ class TestPublishCustomizedPipelineTemplateApi:
def test_post_success(self, app: Flask) -> None:
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)
account = Account(name="Test User", email="test@example.com")
account.id = str(uuid4())
tenant_id = str(uuid4())
payload = {
"name": "Template",
@ -244,8 +254,10 @@ class TestPublishCustomizedPipelineTemplateApi:
return_value=service,
),
):
response, status = method(api, "pipeline-1")
response, status = method(api, tenant_id, account, "pipeline-1")
service.publish_customized_pipeline_template.assert_called_once()
assert service.publish_customized_pipeline_template.call_args.args[2] is account
assert service.publish_customized_pipeline_template.call_args.args[3] == tenant_id
assert status == 204
assert response == ""

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import json
from datetime import datetime
from inspect import unwrap
from types import SimpleNamespace
from typing import TypedDict, Unpack
from unittest.mock import MagicMock, patch
from uuid import uuid4
@ -35,12 +34,15 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
RagPipelineTaskStopApi,
RagPipelineTransformApi,
RagPipelineWorkflowLastRunApi,
RagPipelineWorkflowRunNodeExecutionListApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from graphon.enums import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.account import Account, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from models.enums import CreatorUserRole
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -71,7 +73,7 @@ class WorkflowFactoryPayload(TypedDict):
created_by: str
created_at: datetime
updated_by: str | None
updated_at: datetime
updated_at: datetime | None
environment_variables: list[WorkflowVariablePayload]
conversation_variables: list[WorkflowVariablePayload]
rag_pipeline_variables: list[WorkflowVariablePayload]
@ -90,39 +92,69 @@ class WorkflowFactoryOverrides(TypedDict, total=False):
created_by: str
created_at: datetime
updated_by: str | None
updated_at: datetime
updated_at: datetime | None
environment_variables: list[WorkflowVariablePayload]
conversation_variables: list[WorkflowVariablePayload]
rag_pipeline_variables: list[WorkflowVariablePayload]
def make_node_execution(**overrides: object) -> SimpleNamespace:
payload: dict[str, object] = {
class NodeExecutionOverrides(TypedDict, total=False):
id: str
tenant_id: str
app_id: str
workflow_id: str
workflow_run_id: str | None
index: int
predecessor_node_id: str | None
node_execution_id: str | None
node_id: str
node_type: str
title: str
inputs: str | None
process_data: str | None
outputs: str | None
status: WorkflowNodeExecutionStatus
error: str | None
elapsed_time: float
execution_metadata: str | None
created_at: datetime
created_by_role: CreatorUserRole
created_by: str
finished_at: datetime | None
def make_node_execution(**overrides: Unpack[NodeExecutionOverrides]) -> WorkflowNodeExecutionModel:
payload: NodeExecutionOverrides = {
"id": "node-exec-1",
"tenant_id": DEFAULT_WORKFLOW_TENANT_ID,
"app_id": DEFAULT_WORKFLOW_APP_ID,
"workflow_id": "workflow-1",
"workflow_run_id": None,
"index": 1,
"predecessor_node_id": None,
"node_execution_id": None,
"node_id": "node1",
"node_type": "start",
"title": "Start",
"inputs_dict": {"query": "hello"},
"process_data_dict": {},
"outputs_dict": {"answer": "world"},
"status": "succeeded",
"inputs": json.dumps({"query": "hello"}),
"process_data": json.dumps({}),
"outputs": json.dumps({"answer": "world"}),
"status": WorkflowNodeExecutionStatus.SUCCEEDED,
"error": None,
"elapsed_time": 1.0,
"execution_metadata_dict": {},
"extras": {},
"execution_metadata": json.dumps({}),
"created_at": datetime(2026, 1, 1, 0, 0, 0),
"created_by_role": "account",
"created_by_account": None,
"created_by_end_user": None,
"created_by_role": CreatorUserRole.ACCOUNT,
"created_by": DEFAULT_WORKFLOW_CREATED_BY,
"finished_at": datetime(2026, 1, 1, 0, 0, 1),
"inputs_truncated": False,
"outputs_truncated": False,
"process_data_truncated": False,
}
payload.update(overrides)
return SimpleNamespace(**payload)
execution = WorkflowNodeExecutionModel(
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
**payload,
)
execution.offload_data = []
return execution
def default_workflow_payload() -> WorkflowFactoryPayload:
@ -274,7 +306,10 @@ class TestDraftWorkflowApi:
pipeline = make_pipeline()
user = make_account(id="account-1")
workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))
workflow = make_workflow(
graph=json.dumps({"nodes": [{"id": "restored"}], "edges": []}),
created_at=datetime(2024, 1, 1),
)
service = MagicMock()
service.restore_published_workflow_to_draft.return_value = workflow
@ -289,7 +324,7 @@ class TestDraftWorkflowApi:
result = method(api, user, pipeline, "published-workflow")
assert result["result"] == "success"
assert result["hash"] == "restored-hash"
assert result["hash"] == workflow.unique_hash
def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None:
api = RagPipelineDraftWorkflowRestoreApi()
@ -515,10 +550,7 @@ class TestPublishedPipelineApis:
user = make_account(id="u1")
workflow = MagicMock(
id=str(uuid4()),
created_at=naive_utc_now(),
)
workflow = make_workflow(id=str(uuid4()), created_at=naive_utc_now())
service = MagicMock()
service.publish_workflow.return_value = workflow
@ -576,6 +608,8 @@ class TestMiscApis:
service = MagicMock()
service.get_recommended_plugins.return_value = [{"id": "p1"}]
user = make_account()
tenant_id = "tenant-1"
with (
app.test_request_context("/?type=all"),
@ -584,8 +618,9 @@ class TestMiscApis:
return_value=service,
),
):
result = method(api)
result = method(api, tenant_id, user)
assert result == [{"id": "p1"}]
service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id)
class TestPublishedRagPipelineRunApi:
@ -814,7 +849,7 @@ class TestRagPipelineWorkflowLastRunApi:
method = unwrap(api.get)
pipeline = make_pipeline()
workflow = MagicMock()
workflow = make_workflow()
node_exec = make_node_execution()
service = MagicMock()
@ -853,6 +888,42 @@ class TestRagPipelineWorkflowLastRunApi:
method(api, pipeline, "node1")
class TestRagPipelineWorkflowRunNodeExecutionListApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_node_executions_passes_current_user(self, app: Flask) -> None:
api = RagPipelineWorkflowRunNodeExecutionListApi()
method = unwrap(api.get)
user = make_account()
pipeline = make_pipeline()
run_id = uuid4()
node_exec = make_node_execution(workflow_run_id=str(run_id))
service = MagicMock()
service.get_rag_pipeline_workflow_run_node_executions.return_value = [node_exec]
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, user, pipeline, run_id)
service.get_rag_pipeline_workflow_run_node_executions.assert_called_once_with(
pipeline=pipeline,
run_id=str(run_id),
user=user,
)
assert result["data"][0]["id"] == "node-exec-1"
assert result["data"][0]["inputs"] == {"query": "hello"}
assert result["data"][0]["outputs"] == {"answer": "world"}
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask) -> Flask:

View File

@ -2,7 +2,10 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
from dataclasses import dataclass
from inspect import unwrap
from typing import cast
from unittest.mock import patch
import pytest
from flask import Flask
@ -10,85 +13,103 @@ from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
from controllers.console.explore.error import NotChatAppError
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.model import AppMode
from models.enums import ConversationFromSource, ConversationStatus
from models.model import App, AppMode, Conversation, InstalledApp
from services.errors.conversation import (
ConversationNotExistsError,
LastConversationNotExistsError,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class FakeConversation:
def __init__(self, cid):
self.id = cid
self.name = "test"
self.inputs = {}
self.status = "normal"
self.introduction = ""
@dataclass
class InstalledAppCarrier:
app: App | None
@pytest.fixture
def chat_app():
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
return MagicMock(app=app_model)
def chat_app() -> InstalledApp:
app_model = App(
tenant_id="tenant-1",
name="Chat App",
mode=AppMode.CHAT,
enable_site=True,
enable_api=False,
)
app_model.id = "app-id"
return cast(InstalledApp, InstalledAppCarrier(app=app_model))
@pytest.fixture
def non_chat_app():
app_model = MagicMock(mode=AppMode.COMPLETION)
return MagicMock(app=app_model)
def non_chat_app() -> InstalledApp:
app_model = App(
tenant_id="tenant-1",
name="Completion App",
mode=AppMode.COMPLETION,
enable_site=True,
enable_api=False,
)
app_model.id = "app-id"
return cast(InstalledApp, InstalledAppCarrier(app=app_model))
def make_conversation(*, id: str) -> Conversation:
conversation = Conversation(
app_id="app-id",
mode=AppMode.CHAT,
name="test",
from_source=ConversationFromSource.API,
)
conversation.id = id
conversation.inputs = {}
conversation.status = ConversationStatus.NORMAL
conversation.introduction = ""
return conversation
@pytest.fixture
def user():
user = MagicMock(spec=Account)
def user() -> Account:
user = Account(name="User", email="user.com")
user.id = "uid"
return user
class TestConversationListApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_success(self, app: Flask, chat_app, user):
def test_get_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
pagination = MagicMock(
pagination = InfiniteScrollPagination(
data=[make_conversation(id="c1"), make_conversation(id="c2")],
limit=20,
has_more=False,
data=[FakeConversation("c1"), FakeConversation("c2")],
)
with (
app.test_request_context("/?limit=20"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
return_value=pagination,
),
):
result = method(chat_app)
result = method(api, user, chat_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
def test_last_conversation_not_exists(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
@ -96,47 +117,45 @@ class TestConversationListApi:
),
):
with pytest.raises(NotFound):
method(chat_app)
method(api, user, chat_app)
def test_wrong_app_mode(self, app: Flask, non_chat_app):
def test_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app)
method(api, user, non_chat_app)
class TestConversationApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_delete_success(self, app: Flask, chat_app, user):
def test_delete_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
),
):
result = method(chat_app, "cid")
result = method(api, user, chat_app, "cid")
body, status = result
assert status == 204
assert body == ""
def test_delete_not_found(self, app: Flask, chat_app, user):
def test_delete_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
@ -144,48 +163,46 @@ class TestConversationApi:
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
method(api, user, chat_app, "cid")
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app, "cid")
method(api, user, non_chat_app, "cid")
class TestConversationRenameApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_rename_success(self, app: Flask, chat_app, user):
def test_rename_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
conversation = FakeConversation("cid")
conversation = make_conversation(id="cid")
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
return_value=conversation,
),
):
result = method(chat_app, "cid")
result = method(api, user, chat_app, "cid")
assert result["id"] == "cid"
def test_rename_not_found(self, app: Flask, chat_app, user):
def test_rename_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
@ -193,48 +210,46 @@ class TestConversationRenameApi:
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
method(api, user, chat_app, "cid")
class TestConversationPinApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_pin_success(self, app: Flask, chat_app, user):
def test_pin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pin",
),
):
result = method(chat_app, "cid")
result = method(api, user, chat_app, "cid")
assert result == {"result": "success"}
class TestConversationUnPinApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_unpin_success(self, app: Flask, chat_app, user):
def test_unpin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"unpin",
),
):
result = method(chat_app, "cid")
result = method(api, user, chat_app, "cid")
assert result == {"result": "success"}

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from inspect import unwrap
from unittest.mock import MagicMock, patch
import pytest
@ -31,123 +32,110 @@ from core.plugin.entities.plugin_daemon import CredentialType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def mock_user():
user = MagicMock(spec=Account)
def mock_user() -> Account:
user = Account(name="User", email="user.com")
user.id = "u1"
user.current_tenant_id = "t1"
return user
class TestTriggerProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_icon_success(self, app: Flask):
def test_icon_success(self, app: Flask) -> None:
api = TriggerProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
return_value="icon",
),
):
assert method(api, "github") == "icon"
assert method(api, "t1", "github") == "icon"
def test_list_providers(self, app: Flask):
def test_list_providers(self, app: Flask) -> None:
api = TriggerProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
return_value=[],
),
):
assert method(api) == []
assert method(api, "t1") == []
def test_provider_info(self, app: Flask):
def test_provider_info(self, app: Flask) -> None:
api = TriggerProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
return_value={"id": "p1"},
),
):
assert method(api, "github") == {"id": "p1"}
assert method(api, "t1", "github") == {"id": "p1"}
class TestTriggerSubscriptionListApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_list_success(self, app: Flask):
def test_list_success(self, app: Flask) -> None:
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
return_value=[],
),
):
assert method(api, "github") == []
assert method(api, "t1", mock_user(), "github") == []
def test_list_invalid_provider(self, app: Flask):
def test_list_invalid_provider(self, app: Flask) -> None:
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
side_effect=ValueError("bad"),
),
):
result, status = method(api, "bad")
result, status = method(api, "t1", mock_user(), "bad")
assert status == 404
class TestTriggerSubscriptionBuilderApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_create_builder(self, app: Flask):
def test_create_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
result = method(api, "github")
result = method(api, "t1", mock_user(), "github")
assert "subscription_builder" in result
def test_get_builder(self, app: Flask):
def test_get_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
@ -160,50 +148,47 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app: Flask):
def test_verify_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {"a": 1}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
return_value={"ok": True},
),
):
assert method(api, "github", "b1") == {"ok": True}
assert method(api, "t1", mock_user(), "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app: Flask):
def test_verify_builder_error(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
side_effect=Exception("err"),
),
):
with pytest.raises(ValueError):
method(api, "github", "b1")
method(api, "t1", mock_user(), "github", "b1")
def test_update_builder(self, app: Flask):
def test_update_builder(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "n"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
assert method(api, "t1", "github", "b1") == {"id": "b1"}
def test_logs(self, app: Flask):
def test_logs(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
@ -212,7 +197,6 @@ class TestTriggerSubscriptionBuilderApis:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
return_value=[log],
@ -220,27 +204,26 @@ class TestTriggerSubscriptionBuilderApis:
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app: Flask):
def test_build(self, app: Flask) -> None:
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
return_value=None,
),
):
assert method(api, "github", "b1") == 200
assert method(api, "t1", mock_user(), "github", "b1") == 200
class TestTriggerSubscriptionCrud:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_update_rename_only(self, app: Flask):
def test_update_rename_only(self, app: Flask) -> None:
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -250,43 +233,40 @@ class TestTriggerSubscriptionCrud:
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
):
assert method(api, "s1") == 200
assert method(api, "t1", "s1") == 200
def test_update_not_found(self, app: Flask):
def test_update_not_found(self, app: Flask) -> None:
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "x")
method(api, "t1", "x")
def test_update_rebuild(self, app: Flask):
def test_update_rebuild(self, app: Flask) -> None:
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.OAUTH2
sub.credentials = {}
sub.parameters = {}
sub.credentials = {"token": "old"}
sub.parameters = {"repo": "demo"}
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
@ -295,9 +275,9 @@ class TestTriggerSubscriptionCrud:
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
),
):
assert method(api, "s1") == 200
assert method(api, "t1", "s1") == 200
def test_delete_subscription(self, app: Flask):
def test_delete_subscription(self, app: Flask) -> None:
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -305,7 +285,6 @@ class TestTriggerSubscriptionCrud:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
@ -316,17 +295,16 @@ class TestTriggerSubscriptionCrud:
mock_db.engine = MagicMock()
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
result = method(api, "sub1")
result = method(api, "t1", "sub1")
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app: Flask):
def test_delete_subscription_value_error(self, app: Flask) -> None:
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls,
patch(
@ -338,21 +316,20 @@ class TestTriggerSubscriptionCrud:
session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock()
with pytest.raises(BadRequest):
method(api, "sub1")
method(api, "t1", "sub1")
class TestTriggerOAuthApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_oauth_authorize_success(self, app: Flask):
def test_oauth_authorize_success(self, app: Flask) -> None:
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
@ -370,25 +347,24 @@ class TestTriggerOAuthApis:
return_value=MagicMock(authorization_url="url"),
),
):
resp = method(api, "github")
resp = method(api, "t1", mock_user(), "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app: Flask):
def test_oauth_authorize_no_client(self, app: Flask) -> None:
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "github")
method(api, "t1", mock_user(), "github")
def test_oauth_callback_forbidden(self, app: Flask):
def test_oauth_callback_forbidden(self, app: Flask) -> None:
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -396,7 +372,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app: Flask):
def test_oauth_callback_success(self, app: Flask) -> None:
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -426,7 +402,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app: Flask):
def test_oauth_callback_no_oauth_client(self, app: Flask) -> None:
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -450,7 +426,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app: Flask):
def test_oauth_callback_empty_credentials(self, app: Flask) -> None:
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -481,16 +457,15 @@ class TestTriggerOAuthApis:
class TestTriggerOAuthClientManageApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_client(self, app: Flask):
def test_get_client(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
return_value={},
@ -508,84 +483,79 @@ class TestTriggerOAuthClientManageApi:
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
),
):
result = method(api, "github")
result = method(api, "t1", "github")
assert "configured" in result
def test_post_client(self, app: Flask):
def test_post_client(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
assert method(api, "t1", "github") == {"ok": True}
def test_delete_client(self, app: Flask):
def test_delete_client(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
assert method(api, "t1", "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app: Flask):
def test_oauth_client_post_value_error(self, app: Flask) -> None:
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
side_effect=ValueError("bad"),
),
):
with pytest.raises(BadRequest):
method(api, "github")
method(api, "t1", "github")
class TestTriggerSubscriptionVerifyApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_verify_success(self, app: Flask):
def test_verify_success(self, app: Flask) -> None:
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
return_value={"ok": True},
),
):
assert method(api, "github", "s1") == {"ok": True}
assert method(api, "t1", mock_user(), "github", "s1") == {"ok": True}
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
def test_verify_errors(self, app: Flask, raised_exception):
def test_verify_errors(self, app: Flask, raised_exception: Exception) -> None:
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
side_effect=raised_exception,
),
):
with pytest.raises(BadRequest):
method(api, "github", "s1")
method(api, "t1", mock_user(), "github", "s1")

View File

@ -11,19 +11,29 @@ Covers:
"""
from collections.abc import Generator
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session, sessionmaker
from models import Account, Tenant
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.enums import DataSourceType
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
def _make_account(account_id: str, tenant_id: str) -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
tenant = Tenant(name="Test Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
return account
class TestRagPipelineServiceGetPipeline:
"""Integration tests for RagPipelineService.get_pipeline."""
@ -32,7 +42,7 @@ class TestRagPipelineServiceGetPipeline:
yield
db_session_with_containers.rollback()
def _make_service(self, flask_app_with_containers) -> RagPipelineService:
def _make_service(self, flask_app_with_containers: Flask) -> RagPipelineService:
with (
patch(
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
@ -72,7 +82,7 @@ class TestRagPipelineServiceGetPipeline:
return dataset
def test_get_pipeline_raises_when_dataset_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""get_pipeline raises ValueError when dataset does not exist."""
service = self._make_service(flask_app_with_containers)
@ -81,7 +91,7 @@ class TestRagPipelineServiceGetPipeline:
service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4()))
def test_get_pipeline_raises_when_pipeline_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""get_pipeline raises ValueError when dataset exists but has no linked pipeline."""
tenant_id = str(uuid4())
@ -95,7 +105,7 @@ class TestRagPipelineServiceGetPipeline:
service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)
def test_get_pipeline_returns_pipeline_when_found(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""get_pipeline returns the Pipeline when both Dataset and Pipeline exist."""
tenant_id = str(uuid4())
@ -139,43 +149,44 @@ class TestUpdateCustomizedPipelineTemplate:
db_session.flush()
return template
def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
def test_update_template_succeeds(
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""update_customized_pipeline_template updates name and description."""
tenant_id = str(uuid4())
created_by = str(uuid4())
template = self._create_template(db_session_with_containers, tenant_id, created_by)
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
account = _make_account(created_by, tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="Updated Name",
description="Updated description",
icon_info=IconInfo(icon="🔥"),
)
result = RagPipelineService.update_customized_pipeline_template(template.id, info)
info = PipelineTemplateInfoEntity(
name="Updated Name",
description="Updated description",
icon_info=IconInfo(icon="🔥"),
)
result = RagPipelineService.update_customized_pipeline_template(template.id, info, account, tenant_id)
assert result.name == "Updated Name"
assert result.description == "Updated description"
def test_update_template_raises_when_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""update_customized_pipeline_template raises ValueError when template doesn't exist."""
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
tenant_id = str(uuid4())
account = _make_account(str(uuid4()), tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="New Name",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.update_customized_pipeline_template(str(uuid4()), info)
info = PipelineTemplateInfoEntity(
name="New Name",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.update_customized_pipeline_template(str(uuid4()), info, account, tenant_id)
def test_update_template_raises_on_duplicate_name(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""update_customized_pipeline_template raises ValueError when new name already exists."""
tenant_id = str(uuid4())
@ -184,16 +195,15 @@ class TestUpdateCustomizedPipelineTemplate:
self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate")
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
account = _make_account(created_by, tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
info = PipelineTemplateInfoEntity(
name="Duplicate",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Template name is already exists"):
RagPipelineService.update_customized_pipeline_template(template1.id, info)
info = PipelineTemplateInfoEntity(
name="Duplicate",
description="desc",
icon_info=IconInfo(icon="📄"),
)
with pytest.raises(ValueError, match="Template name is already exists"):
RagPipelineService.update_customized_pipeline_template(template1.id, info, account, tenant_id)
class TestDeleteCustomizedPipelineTemplate:
@ -221,7 +231,9 @@ class TestDeleteCustomizedPipelineTemplate:
db_session.flush()
return template
def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
def test_delete_template_succeeds(
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""delete_customized_pipeline_template removes the template from the DB."""
tenant_id = str(uuid4())
created_by = str(uuid4())
@ -229,27 +241,23 @@ class TestDeleteCustomizedPipelineTemplate:
template_id = template.id
db_session_with_containers.flush()
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
RagPipelineService.delete_customized_pipeline_template(template_id, tenant_id)
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
RagPipelineService.delete_customized_pipeline_template(template_id)
# Verify the record is deleted within the same context
from sqlalchemy import select
# Verify the record is deleted within the same context
from sqlalchemy import select
from extensions.ext_database import db as ext_db
from extensions.ext_database import db as ext_db
remaining = ext_db.session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
)
assert remaining is None
remaining = ext_db.session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
)
assert remaining is None
def test_delete_template_raises_when_not_found(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
) -> None:
"""delete_customized_pipeline_template raises ValueError when template doesn't exist."""
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
tenant_id = str(uuid4())
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.delete_customized_pipeline_template(str(uuid4()))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
RagPipelineService.delete_customized_pipeline_template(str(uuid4()), tenant_id)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from unittest.mock import Mock, patch
from unittest.mock import patch
from uuid import uuid4
import pytest
@ -8,6 +8,7 @@ from flask import Flask
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import Account, Tenant
from models.dataset import Dataset, DatasetMetadataBinding, Document
from models.enums import DataSourceType, DocumentCreatedFrom
from services.entities.knowledge_entities.knowledge_entities import (
@ -33,7 +34,7 @@ def _create_dataset(db_session: Session, *, tenant_id: str, built_in_field_enabl
def _create_document(
db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None
db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict[str, str] | None = None
) -> Document:
document = Document(
tenant_id=tenant_id,
@ -63,18 +64,21 @@ class TestMetadataPartialUpdate:
return str(uuid4())
@pytest.fixture
def mock_current_account(self, user_id, tenant_id):
account = Mock(id=user_id, current_tenant_id=tenant_id)
with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)):
yield account
def current_account(self, user_id: str, tenant_id: str) -> Account:
account = Account(name="Test User", email=f"{user_id}@example.com")
account.id = user_id
tenant = Tenant(name="Test Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
return account
def test_partial_update_merges_metadata(
self,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id: str,
mock_current_account,
):
current_account: Account,
) -> None:
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
db_session_with_containers,
@ -91,7 +95,7 @@ class TestMetadataPartialUpdate:
)
metadata_args = MetadataOperationData(operation_data=[operation])
MetadataService.update_documents_metadata(dataset, metadata_args)
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
db_session_with_containers.expire_all()
updated_doc = db_session_with_containers.get(Document, document.id)
@ -104,8 +108,8 @@ class TestMetadataPartialUpdate:
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id: str,
mock_current_account,
):
current_account: Account,
) -> None:
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
db_session_with_containers,
@ -122,7 +126,7 @@ class TestMetadataPartialUpdate:
)
metadata_args = MetadataOperationData(operation_data=[operation])
MetadataService.update_documents_metadata(dataset, metadata_args)
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
db_session_with_containers.expire_all()
updated_doc = db_session_with_containers.get(Document, document.id)
@ -134,10 +138,10 @@ class TestMetadataPartialUpdate:
self,
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id,
user_id,
mock_current_account,
):
tenant_id: str,
user_id: str,
current_account: Account,
) -> None:
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
db_session_with_containers,
@ -164,7 +168,7 @@ class TestMetadataPartialUpdate:
)
metadata_args = MetadataOperationData(operation_data=[operation])
MetadataService.update_documents_metadata(dataset, metadata_args)
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
db_session_with_containers.expire_all()
bindings = db_session_with_containers.scalars(
@ -180,8 +184,8 @@ class TestMetadataPartialUpdate:
flask_app_with_containers: Flask,
db_session_with_containers: Session,
tenant_id: str,
mock_current_account,
):
current_account: Account,
) -> None:
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
db_session_with_containers,
@ -200,4 +204,4 @@ class TestMetadataPartialUpdate:
with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")):
with pytest.raises(RuntimeError, match="database connection lost"):
MetadataService.update_documents_metadata(dataset, metadata_args)
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)

View File

@ -1,4 +1,6 @@
from unittest.mock import create_autospec, patch
from collections.abc import Generator
from typing import TypedDict
from unittest.mock import Mock, patch
import pytest
from faker import Faker
@ -6,21 +8,25 @@ from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
from models.enums import DataSourceType, DocumentCreatedFrom
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
class MetadataServiceDeps(TypedDict):
redis_client: Mock
document_service: Mock
class TestMetadataService:
"""Integration tests for MetadataService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
def mock_external_service_dependencies(self) -> Generator[MetadataServiceDeps, None, None]:
"""Mock setup for external service dependencies."""
with (
patch("libs.login.current_user", create_autospec(Account, instance=True)) as mock_current_user,
patch("services.metadata_service.redis_client") as mock_redis_client,
patch("services.dataset_service.DocumentService") as mock_document_service,
):
@ -30,12 +36,15 @@ class TestMetadataService:
mock_redis_client.delete.return_value = 1
yield {
"current_user": mock_current_user,
"redis_client": mock_redis_client,
"document_service": mock_document_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
def _create_test_account_and_tenant(
self,
db_session_with_containers: Session,
mock_external_service_dependencies: MetadataServiceDeps,
) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
@ -53,7 +62,7 @@ class TestMetadataService:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
@ -62,7 +71,7 @@ class TestMetadataService:
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -83,8 +92,12 @@ class TestMetadataService:
return account, tenant
def _create_test_dataset(
self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant
):
self,
db_session_with_containers: Session,
mock_external_service_dependencies: MetadataServiceDeps,
account: Account,
tenant: Tenant,
) -> Dataset:
"""
Helper method to create a test dataset for testing.
@ -114,8 +127,12 @@ class TestMetadataService:
return dataset
def _create_test_document(
self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account
):
self,
db_session_with_containers: Session,
mock_external_service_dependencies: MetadataServiceDeps,
dataset: Dataset,
account: Account,
) -> Document:
"""
Helper method to create a test document for testing.
@ -149,7 +166,9 @@ class TestMetadataService:
return document
def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
def test_create_metadata_success(
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful metadata creation with valid parameters.
"""
@ -161,14 +180,10 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata_args = MetadataArgs(type="string", name="test_metadata")
# Act: Execute the method under test
result = MetadataService.create_metadata(dataset.id, metadata_args)
result = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -185,8 +200,8 @@ class TestMetadataService:
assert result.created_at is not None
def test_create_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata creation fails when name exceeds 255 characters.
"""
@ -198,20 +213,16 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
long_name = "a" * 256 # 256 characters, exceeding 255 limit
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name)
metadata_args = MetadataArgs(type="string", name=long_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.create_metadata(dataset.id, metadata_args)
MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
def test_create_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata creation fails when name already exists in the same dataset.
"""
@ -223,24 +234,20 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create first metadata
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name")
MetadataService.create_metadata(dataset.id, first_metadata_args)
first_metadata_args = MetadataArgs(type="string", name="duplicate_name")
MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id)
# Try to create second metadata with same name
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name")
second_metadata_args = MetadataArgs(type="number", name="duplicate_name")
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists."):
MetadataService.create_metadata(dataset.id, second_metadata_args)
MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id)
def test_create_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata creation fails when name conflicts with built-in field names.
"""
@ -252,21 +259,17 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Try to create metadata with built-in field name
built_in_field_name = BuiltInField.document_name
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name)
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.create_metadata(dataset.id, metadata_args)
MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
def test_update_metadata_name_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful metadata name update with valid parameters.
"""
@ -278,17 +281,13 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Act: Execute the method under test
new_name = "new_name"
result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name)
result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name, account, tenant.id)
# Assert: Verify the expected outcomes
assert result is not None
@ -302,8 +301,8 @@ class TestMetadataService:
assert result.name == new_name
def test_update_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata name update fails when new name exceeds 255 characters.
"""
@ -315,24 +314,20 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Try to update with too long name
long_name = "a" * 256 # 256 characters, exceeding 255 limit
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name)
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name, account, tenant.id)
def test_update_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata name update fails when new name already exists in the same dataset.
"""
@ -344,24 +339,20 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create two metadata entries
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata")
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args)
first_metadata_args = MetadataArgs(type="string", name="first_metadata")
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id)
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata")
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args)
second_metadata_args = MetadataArgs(type="number", name="second_metadata")
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id)
# Try to update first metadata with second metadata's name
with pytest.raises(ValueError, match="Metadata name already exists."):
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata")
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata", account, tenant.id)
def test_update_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata name update fails when new name conflicts with built-in field names.
"""
@ -373,23 +364,19 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="old_name")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Try to update with built-in field name
built_in_field_name = BuiltInField.document_name
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name, account, tenant.id)
def test_update_metadata_name_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata name update fails when metadata ID does not exist.
"""
@ -401,10 +388,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Try to update non-existent metadata
import uuid
@ -412,12 +395,14 @@ class TestMetadataService:
new_name = "new_name"
# Act: Execute the method under test
result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name)
result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name, account, tenant.id)
# Assert: Verify the method returns None when metadata is not found
assert result is None
def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
def test_delete_metadata_success(
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful metadata deletion with valid parameters.
"""
@ -429,13 +414,9 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata first
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="to_be_deleted")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Act: Execute the method under test
result = MetadataService.delete_metadata(dataset.id, metadata.id)
@ -449,7 +430,9 @@ class TestMetadataService:
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
assert deleted_metadata is None
def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
def test_delete_metadata_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata deletion fails when metadata ID does not exist.
"""
@ -461,10 +444,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Try to delete non-existent metadata
import uuid
@ -477,8 +456,8 @@ class TestMetadataService:
assert result is None
def test_delete_metadata_with_document_bindings(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata deletion successfully removes document metadata bindings.
"""
@ -493,13 +472,9 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, dataset, account
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Create metadata binding
binding = DatasetMetadataBinding(
@ -531,7 +506,9 @@ class TestMetadataService:
# Note: The service attempts to update document metadata but may not succeed
# due to mock configuration. The main functionality (metadata deletion) is verified.
def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
def test_get_built_in_fields_success(
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful retrieval of built-in metadata fields.
"""
@ -557,8 +534,8 @@ class TestMetadataService:
assert "time" in field_types
def test_enable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful enabling of built-in fields for a dataset.
"""
@ -573,10 +550,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, dataset, account
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [
document
@ -591,14 +564,14 @@ class TestMetadataService:
# Assert: Verify the expected outcomes
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True
assert dataset.built_in_field_enabled
# Note: Document metadata update depends on DocumentService mock working correctly
# The main functionality (enabling built-in fields) is verified
def test_enable_built_in_field_already_enabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test enabling built-in fields when they are already enabled.
"""
@ -610,10 +583,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Enable built-in fields first
dataset.built_in_field_enabled = True
@ -621,7 +590,9 @@ class TestMetadataService:
db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
Document
]()
# Act: Execute the method under test
MetadataService.enable_built_in_field(dataset)
@ -631,8 +602,8 @@ class TestMetadataService:
assert dataset.built_in_field_enabled is True
def test_enable_built_in_field_with_no_documents(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test enabling built-in fields for a dataset with no documents.
"""
@ -644,12 +615,10 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
Document
]()
# Act: Execute the method under test
MetadataService.enable_built_in_field(dataset)
@ -657,11 +626,11 @@ class TestMetadataService:
# Assert: Verify the expected outcomes
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True
assert dataset.built_in_field_enabled
def test_disable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful disabling of built-in fields for a dataset.
"""
@ -676,10 +645,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, dataset, account
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Enable built-in fields first
dataset.built_in_field_enabled = True
@ -713,8 +678,8 @@ class TestMetadataService:
# The main functionality (disabling built-in fields) is verified
def test_disable_built_in_field_already_disabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test disabling built-in fields when they are already disabled.
"""
@ -726,15 +691,13 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Verify dataset starts with built-in fields disabled
assert dataset.built_in_field_enabled is False
# Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
Document
]()
# Act: Execute the method under test
MetadataService.disable_built_in_field(dataset)
@ -742,11 +705,11 @@ class TestMetadataService:
# Assert: Verify the method returns early without changes
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False
assert not dataset.built_in_field_enabled
def test_disable_built_in_field_with_no_documents(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test disabling built-in fields for a dataset with no documents.
"""
@ -758,10 +721,6 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Enable built-in fields first
dataset.built_in_field_enabled = True
@ -769,7 +728,9 @@ class TestMetadataService:
db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
Document
]()
# Act: Execute the method under test
MetadataService.disable_built_in_field(dataset)
@ -779,8 +740,8 @@ class TestMetadataService:
assert dataset.built_in_field_enabled is False
def test_update_documents_metadata_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful update of documents metadata.
"""
@ -795,13 +756,9 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, dataset, account
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Mock DocumentService.get_document
mock_external_service_dependencies["document_service"].get_document.return_value = document
@ -820,7 +777,7 @@ class TestMetadataService:
operation_data = MetadataOperationData(operation_data=[operation])
# Act: Execute the method under test
MetadataService.update_documents_metadata(dataset, operation_data)
MetadataService.update_documents_metadata(dataset, operation_data, account)
# Assert: Verify the expected outcomes
@ -841,8 +798,8 @@ class TestMetadataService:
assert binding.dataset_id == dataset.id
def test_update_documents_metadata_with_built_in_fields_enabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test update of documents metadata when built-in fields are enabled.
"""
@ -863,13 +820,9 @@ class TestMetadataService:
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Mock DocumentService.get_document
mock_external_service_dependencies["document_service"].get_document.return_value = document
@ -888,7 +841,7 @@ class TestMetadataService:
operation_data = MetadataOperationData(operation_data=[operation])
# Act: Execute the method under test
MetadataService.update_documents_metadata(dataset, operation_data)
MetadataService.update_documents_metadata(dataset, operation_data, account)
# Assert: Verify the expected outcomes
# Verify document metadata was updated with both custom and built-in fields
@ -901,8 +854,8 @@ class TestMetadataService:
# The main functionality (custom metadata update) is verified
def test_update_documents_metadata_document_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test update of documents metadata when document is not found.
"""
@ -914,13 +867,9 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Create metadata operation data
from services.entities.knowledge_entities.knowledge_entities import (
@ -941,11 +890,11 @@ class TestMetadataService:
# Act & Assert: The method should raise ValueError("Document not found.")
# because the exception is now re-raised after rollback
with pytest.raises(ValueError, match="Document not found"):
MetadataService.update_documents_metadata(dataset, operation_data)
MetadataService.update_documents_metadata(dataset, operation_data, account)
def test_knowledge_base_metadata_lock_check_dataset_id(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata lock check for dataset operations.
"""
@ -967,8 +916,8 @@ class TestMetadataService:
assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}"
def test_knowledge_base_metadata_lock_check_document_id(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata lock check for document operations.
"""
@ -990,8 +939,8 @@ class TestMetadataService:
assert call_args[0][0] == f"document_metadata_lock_{document_id}"
def test_knowledge_base_metadata_lock_check_lock_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata lock check when lock already exists.
"""
@ -1007,8 +956,8 @@ class TestMetadataService:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
def test_knowledge_base_metadata_lock_check_document_lock_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test metadata lock check when document lock already exists.
"""
@ -1022,8 +971,8 @@ class TestMetadataService:
MetadataService.knowledge_base_metadata_lock_check(None, document_id)
def test_get_dataset_metadatas_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test successful retrieval of dataset metadata information.
"""
@ -1035,13 +984,9 @@ class TestMetadataService:
db_session_with_containers, mock_external_service_dependencies, account, tenant
)
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Create document and metadata binding
document = self._create_test_document(
@ -1079,8 +1024,8 @@ class TestMetadataService:
assert result["built_in_field_enabled"] is False
def test_get_dataset_metadatas_with_built_in_fields_enabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test retrieval of dataset metadata when built-in fields are enabled.
"""
@ -1098,13 +1043,9 @@ class TestMetadataService:
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
mock_external_service_dependencies["current_user"].id = account.id
# Create metadata
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
metadata_args = MetadataArgs(type="string", name="test_metadata")
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
# Act: Execute the method under test
result = MetadataService.get_dataset_metadatas(dataset)
@ -1122,8 +1063,8 @@ class TestMetadataService:
assert result["built_in_field_enabled"] is True
def test_get_dataset_metadatas_no_metadata(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
) -> None:
"""
Test retrieval of dataset metadata when no metadata exists.
"""

View File

@ -15,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
from models.account import Account
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
@ -50,24 +51,31 @@ def _payload() -> dict[str, object]:
}
def _account() -> Account:
account = Account(name="Test User", email="test@example.com")
account.id = "account-1"
return account
class TestPipelineTemplateListApi:
def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None:
api = PipelineTemplateListApi()
method = unwrap(api.get)
service_calls: list[tuple[str, str]] = []
tenant_id = "tenant-1"
service_calls: list[tuple[str, str, str]] = []
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
service_calls.append((template_type, language))
def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]:
service_calls.append((template_type, language, current_tenant_id))
return {"pipeline_templates": [_template_item()]}
with (
app.test_request_context("/rag/pipeline/templates"),
patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates),
):
response, status = method(api)
response, status = method(api, tenant_id)
assert status == 200
assert service_calls == [("built-in", "en-US")]
assert service_calls == [("built-in", "en-US", tenant_id)]
assert response == {
"pipeline_templates": [
{
@ -81,21 +89,22 @@ class TestPipelineTemplateListApi:
def test_get_passes_explicit_query_to_service(self, app: Flask) -> None:
api = PipelineTemplateListApi()
method = unwrap(api.get)
service_calls: list[tuple[str, str]] = []
tenant_id = "tenant-1"
service_calls: list[tuple[str, str, str]] = []
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
service_calls.append((template_type, language))
def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]:
service_calls.append((template_type, language, current_tenant_id))
return {"pipeline_templates": []}
with (
app.test_request_context("/rag/pipeline/templates?type=customized&language=ja-JP"),
patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates),
):
response, status = method(api)
response, status = method(api, tenant_id)
assert status == 200
assert response == {"pipeline_templates": []}
assert service_calls == [("customized", "ja-JP")]
assert service_calls == [("customized", "ja-JP", tenant_id)]
class TestPipelineTemplateDetailApi:
@ -140,22 +149,28 @@ class TestCustomizedPipelineTemplateApi:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
payload = _payload()
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
account = _account()
tenant_id = "tenant-1"
service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = []
def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None:
service_calls.append((template_id, template_info))
def update_template(
template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str
) -> None:
service_calls.append((template_id, template_info, current_user, current_tenant_id))
with (
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template),
):
response, status = method(api, "template-1")
response, status = method(api, tenant_id, account, "template-1")
assert (response, status) == ("", 204)
assert len(service_calls) == 1
template_id, template_info = service_calls[0]
template_id, template_info, current_user, current_tenant_id = service_calls[0]
assert template_id == "template-1"
assert current_user is account
assert current_tenant_id == tenant_id
assert template_info.name == "Updated template"
assert template_info.description == "Updated description"
assert template_info.icon_info.model_dump() == {
@ -172,22 +187,28 @@ class TestCustomizedPipelineTemplateApi:
"name": "Updated template",
"description": "Updated description",
}
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
account = _account()
tenant_id = "tenant-1"
service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = []
def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None:
service_calls.append((template_id, template_info))
def update_template(
template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str
) -> None:
service_calls.append((template_id, template_info, current_user, current_tenant_id))
with (
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template),
):
response, status = method(api, "template-1")
response, status = method(api, tenant_id, account, "template-1")
assert (response, status) == ("", 204)
assert len(service_calls) == 1
template_id, template_info = service_calls[0]
template_id, template_info, current_user, current_tenant_id = service_calls[0]
assert template_id == "template-1"
assert current_user is account
assert current_tenant_id == tenant_id
assert template_info.icon_info.model_dump() == {
"icon": "",
"icon_background": None,
@ -198,19 +219,20 @@ class TestCustomizedPipelineTemplateApi:
def test_delete_returns_empty_204(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
deleted_template_ids: list[str] = []
tenant_id = "tenant-1"
deleted_templates: list[tuple[str, str]] = []
def delete_template(template_id: str) -> None:
deleted_template_ids.append(template_id)
def delete_template(template_id: str, current_tenant_id: str) -> None:
deleted_templates.append((template_id, current_tenant_id))
with (
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="DELETE"),
patch.object(module.RagPipelineService, "delete_customized_pipeline_template", side_effect=delete_template),
):
response, status = method(api, "template-1")
response, status = method(api, tenant_id, "template-1")
assert (response, status) == ("", 204)
assert deleted_template_ids == ["template-1"]
assert deleted_templates == [("template-1", tenant_id)]
def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
@ -292,21 +314,25 @@ class TestPublishCustomizedPipelineTemplateApi:
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)
payload = _payload()
service_calls: list[tuple[str, dict[str, object]]] = []
account = _account()
tenant_id = "tenant-1"
service_calls: list[tuple[str, dict[str, object], Account, str]] = []
class Service:
def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None:
service_calls.append((pipeline_id, data))
def publish_customized_pipeline_template(
self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str
) -> None:
service_calls.append((pipeline_id, data, current_user, current_tenant_id))
with (
app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(module, "RagPipelineService", Service),
):
response, status = method(api, "pipeline-1")
response, status = method(api, tenant_id, account, "pipeline-1")
assert (response, status) == ("", 204)
assert service_calls == [("pipeline-1", payload)]
assert service_calls == [("pipeline-1", payload, account, tenant_id)]
def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None:
api = PublishCustomizedPipelineTemplateApi()
@ -315,18 +341,22 @@ class TestPublishCustomizedPipelineTemplateApi:
"name": "Published template",
"description": "Description",
}
service_calls: list[tuple[str, dict[str, object]]] = []
account = _account()
tenant_id = "tenant-1"
service_calls: list[tuple[str, dict[str, object], Account, str]] = []
class Service:
def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None:
service_calls.append((pipeline_id, data))
def publish_customized_pipeline_template(
self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str
) -> None:
service_calls.append((pipeline_id, data, current_user, current_tenant_id))
with (
app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(module, "RagPipelineService", Service),
):
response, status = method(api, "pipeline-1")
response, status = method(api, tenant_id, account, "pipeline-1")
assert (response, status) == ("", 204)
assert service_calls == [
@ -341,5 +371,7 @@ class TestPublishCustomizedPipelineTemplateApi:
"icon_url": None,
},
},
account,
tenant_id,
)
]

View File

@ -1,4 +1,5 @@
import uuid
from inspect import unwrap
from unittest.mock import PropertyMock, patch
import pytest
@ -8,16 +9,10 @@ from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.hit_testing import HitTestingApi
from models.account import Account, Tenant, TenantAccountRole
from models.dataset import Dataset
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_hit_testing")
@ -35,6 +30,17 @@ def dataset():
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
@pytest.fixture
def account() -> Account:
account = Account(name="User", email="user@example.com")
account.id = "account-1"
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
account._current_tenant = tenant
account.role = TenantAccountRole.OWNER
return account
def hit_testing_record() -> dict[str, object]:
return {
"segment": {
@ -98,7 +104,7 @@ def bypass_decorators(mocker: MockerFixture):
class TestHitTestingApi:
def test_hit_testing_success(self, app: Flask, dataset, dataset_id):
def test_hit_testing_success(self, app: Flask, dataset, dataset_id, account: Account):
api = HitTestingApi()
method = unwrap(api.post)
@ -129,13 +135,13 @@ class TestHitTestingApi:
return_value={"query": {"content": "what is vector search"}, "records": []},
),
):
result = method(api, dataset_id)
result = method(api, account, "tenant-1", dataset_id)
assert "query" in result
assert "records" in result
assert result["records"] == []
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id):
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id, account: Account):
api = HitTestingApi()
method = unwrap(api.post)
@ -167,7 +173,7 @@ class TestHitTestingApi:
return_value={"query": {"content": payload["query"]}, "records": records},
),
):
result = method(api, dataset_id)
result = method(api, account, "tenant-1", dataset_id)
assert result["query"] == {"content": payload["query"]}
assert result["records"][0]["segment"]["keywords"] == []
@ -175,7 +181,7 @@ class TestHitTestingApi:
assert result["records"][0]["files"] == []
assert result["records"][0]["score"] is None
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id, account: Account):
api = HitTestingApi()
method = unwrap(api.post)
@ -198,9 +204,9 @@ class TestHitTestingApi:
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
method(api, account, "tenant-1", dataset_id)
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id):
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id, account: Account):
api = HitTestingApi()
method = unwrap(api.post)
@ -228,4 +234,4 @@ class TestHitTestingApi:
),
):
with pytest.raises(ValueError, match="Invalid parameters"):
method(api, dataset_id)
method(api, account, "tenant-1", dataset_id)

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -21,7 +21,7 @@ from core.errors.error import (
QuotaExceededError,
)
from graphon.model_runtime.errors.invoke import InvokeError
from models.account import Account
from models.account import Account, Tenant, TenantAccountRole
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -29,19 +29,15 @@ from services.hit_testing_service import HitTestingService
@pytest.fixture
def account():
acc = MagicMock(spec=Account)
acc = Account(name="User", email="user@example.com")
acc.id = "account-1"
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
acc._current_tenant = tenant
acc.role = TenantAccountRole.OWNER
return acc
@pytest.fixture(autouse=True)
def patch_current_user(mocker, account):
"""Patch current_user to a valid Account."""
mocker.patch(
"controllers.console.datasets.hit_testing_base.current_user",
account,
)
@pytest.fixture
def dataset():
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
@ -86,7 +82,7 @@ def hit_testing_record() -> dict[str, object]:
class TestGetAndValidateDataset:
def test_success(self, dataset):
def test_success(self, dataset, account):
with (
patch.object(
DatasetService,
@ -98,20 +94,20 @@ class TestGetAndValidateDataset:
"check_dataset_permission",
),
):
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
assert result == dataset
def test_dataset_not_found(self):
def test_dataset_not_found(self, account):
with patch.object(
DatasetService,
"get_dataset",
return_value=None,
):
with pytest.raises(NotFound, match="Dataset not found"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
def test_permission_denied(self, dataset):
def test_permission_denied(self, dataset, account):
with (
patch.object(
DatasetService,
@ -125,7 +121,7 @@ class TestGetAndValidateDataset:
),
):
with pytest.raises(Forbidden, match="no access"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
class TestHitTestingArgsCheck:
@ -164,7 +160,7 @@ class TestParseArgs:
class TestPerformHitTesting:
def test_success(self, dataset):
def test_success(self, dataset, account):
response = {
"query": {"content": "hello"},
"records": [],
@ -175,12 +171,12 @@ class TestPerformHitTesting:
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
assert result["query"] == {"content": "hello"}
assert result["records"] == []
def test_success_prepares_nullable_list_fields(self, dataset):
def test_success_prepares_nullable_list_fields(self, dataset, account):
response = {
"query": {"content": "hello"},
"records": [hit_testing_record()],
@ -191,7 +187,7 @@ class TestPerformHitTesting:
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
assert result["query"] == {"content": "hello"}
record = result["records"][0]
@ -203,7 +199,7 @@ class TestPerformHitTesting:
assert record["tsne_position"] is None
assert record["summary"] is None
def test_invalid_query_response_raises_value_error(self, dataset):
def test_invalid_query_response_raises_value_error(self, dataset, account):
with (
patch.object(
HitTestingService,
@ -212,7 +208,7 @@ class TestPerformHitTesting:
),
pytest.raises(ValueError, match="Invalid hit testing query response"),
):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_invalid_records_response_raises_value_error(self):
with pytest.raises(ValueError, match="Invalid hit testing records response"):
@ -222,74 +218,74 @@ class TestPerformHitTesting:
with pytest.raises(ValueError, match="Invalid hit testing record response"):
DatasetsHitTestingBase._prepare_hit_testing_records(["record"])
def test_index_not_initialized(self, dataset):
def test_index_not_initialized(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=services.errors.index.IndexNotInitializedError(),
):
with pytest.raises(DatasetNotInitializedError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_provider_token_not_init(self, dataset):
def test_provider_token_not_init(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ProviderTokenNotInitError("token missing"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_quota_exceeded(self, dataset):
def test_quota_exceeded(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=QuotaExceededError(),
):
with pytest.raises(ProviderQuotaExceededError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_model_not_supported(self, dataset):
def test_model_not_supported(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ModelCurrentlyNotSupportError(),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_llm_bad_request(self, dataset):
def test_llm_bad_request(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=LLMBadRequestError("bad request"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_invoke_error(self, dataset):
def test_invoke_error(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=InvokeError("invoke failed"),
):
with pytest.raises(CompletionRequestError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_value_error(self, dataset):
def test_value_error(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ValueError("bad args"),
):
with pytest.raises(ValueError, match="bad args"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
def test_unexpected_error(self, dataset):
def test_unexpected_error(self, dataset, account):
with patch.object(
HitTestingService,
"retrieve",
side_effect=Exception("boom"),
):
with pytest.raises(InternalServerError, match="boom"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")

View File

@ -1,4 +1,5 @@
import uuid
from inspect import unwrap
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -14,6 +15,7 @@ from controllers.console.datasets.metadata import (
DatasetMetadataCreateApi,
DocumentMetadataEditApi,
)
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
@ -22,13 +24,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_dataset_metadata")
@ -37,8 +32,8 @@ def app():
@pytest.fixture
def current_user():
user = MagicMock()
def current_user() -> Account:
user = Account(name="Test User", email="test@example.com")
user.id = "user-1"
return user
@ -116,7 +111,7 @@ class TestDatasetMetadataCreateApi:
return_value={"id": "m1", "type": "string", "name": "author"},
),
):
result, status = method(api, current_user, dataset_id)
result, status = method(api, "tenant-1", current_user, dataset_id)
assert status == 201
assert result["type"] == "string"
@ -151,7 +146,7 @@ class TestDatasetMetadataCreateApi:
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, current_user, dataset_id)
method(api, "tenant-1", current_user, dataset_id)
class TestDatasetMetadataGetApi:
@ -227,7 +222,7 @@ class TestDatasetMetadataApi:
return_value={"id": "m1", "type": "string", "name": "updated-name"},
),
):
result, status = method(api, current_user, dataset_id, metadata_id)
result, status = method(api, "tenant-1", current_user, dataset_id, metadata_id)
assert status == 200
assert result["type"] == "string"

View File

@ -1,3 +1,4 @@
from inspect import unwrap
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -15,15 +16,11 @@ from models.model import AppMode
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user():
return MagicMock(spec=Account)
account = Account(name="User", email="user.com")
account.id = "uid"
return account
@pytest.fixture
@ -59,7 +56,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -71,18 +67,18 @@ class TestCompletionApi:
return_value=("ok", 200),
),
):
result = method(completion_app)
result = method(api, user, completion_app)
assert result == ("ok", 200)
def test_post_wrong_app_mode(self):
def test_post_wrong_app_mode(self, user):
api = completion_module.CompletionApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
with pytest.raises(NotCompletionAppError):
method(installed_app)
method(api, user, installed_app)
def test_conversation_completed(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -91,7 +87,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -99,7 +94,7 @@ class TestCompletionApi:
),
):
with pytest.raises(ConversationCompletedError):
method(completion_app)
method(api, user, completion_app)
def test_internal_error(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -108,7 +103,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -116,7 +110,7 @@ class TestCompletionApi:
),
):
with pytest.raises(InternalServerError):
method(completion_app)
method(api, user, completion_app)
def test_conversation_not_exists(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -125,7 +119,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -133,7 +126,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.NotFound):
method(completion_app)
method(api, user, completion_app)
def test_app_unavailable(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -142,7 +135,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -150,7 +142,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(completion_app)
method(api, user, completion_app)
def test_provider_not_initialized(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -159,7 +151,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -167,7 +158,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(completion_app)
method(api, user, completion_app)
def test_quota_exceeded(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -176,7 +167,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -184,7 +174,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(completion_app)
method(api, user, completion_app)
def test_model_not_supported(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -193,7 +183,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -201,7 +190,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(completion_app)
method(api, user, completion_app)
def test_invoke_error(self, app: Flask, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
@ -210,7 +199,6 @@ class TestCompletionApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -218,7 +206,7 @@ class TestCompletionApi:
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(completion_app)
method(api, user, completion_app)
class TestCompletionStopApi:
@ -250,7 +238,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -262,18 +249,18 @@ class TestChatApi:
return_value=("ok", 200),
),
):
result = method(chat_app)
result = method(api, user, chat_app)
assert result == ("ok", 200)
def test_post_not_chat_app(self):
def test_post_not_chat_app(self, user):
api = completion_module.ChatApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
with pytest.raises(NotChatAppError):
method(installed_app)
method(api, user, installed_app)
def test_rate_limit_error(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -282,7 +269,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -290,7 +276,7 @@ class TestChatApi:
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(chat_app)
method(api, user, chat_app)
def test_conversation_completed_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -299,7 +285,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -307,7 +292,7 @@ class TestChatApi:
),
):
with pytest.raises(ConversationCompletedError):
method(chat_app)
method(api, user, chat_app)
def test_conversation_not_exists_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -316,7 +301,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -324,7 +308,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.NotFound):
method(chat_app)
method(api, user, chat_app)
def test_app_unavailable_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -333,7 +317,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -341,7 +324,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(chat_app)
method(api, user, chat_app)
def test_provider_not_initialized_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -350,7 +333,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -358,7 +340,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(chat_app)
method(api, user, chat_app)
def test_quota_exceeded_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -367,7 +349,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -375,7 +356,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(chat_app)
method(api, user, chat_app)
def test_model_not_supported_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -384,7 +365,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -392,7 +372,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(chat_app)
method(api, user, chat_app)
def test_invoke_error_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -401,7 +381,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -409,7 +388,7 @@ class TestChatApi:
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(chat_app)
method(api, user, chat_app)
def test_internal_error_chat(self, app: Flask, chat_app, user, payload_patch):
api = completion_module.ChatApi()
@ -418,7 +397,6 @@ class TestChatApi:
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
@ -426,7 +404,7 @@ class TestChatApi:
),
):
with pytest.raises(InternalServerError):
method(chat_app)
method(api, user, chat_app)
class TestChatStopApi:

View File

@ -1,32 +1,36 @@
from inspect import unwrap
from pytest_mock import MockerFixture
from werkzeug.exceptions import Unauthorized
from models import Account
from services.feature_service import FeatureModel, LimitationModel, SystemFeatureModel
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_account() -> Account:
account = Account(name="Alice", email="alice@example.com")
account.id = "account-1"
return account
class TestFeatureApi:
def test_get_tenant_features_success(self, mocker: MockerFixture):
from controllers.console.feature import FeatureApi
features = FeatureModel(
knowledge_rate_limit=42,
vector_space=LimitationModel(size=1, limit=2),
)
get_features = mocker.patch("controllers.console.feature.FeatureService.get_features")
get_features.return_value.model_dump.return_value = {
"features": {"feature_a": True},
"vector_space": {"size": 1, "limit": 2},
}
get_features.return_value = features
api = FeatureApi()
raw_get = unwrap(FeatureApi.get)
result = raw_get(api, "tenant_123")
assert result == {"features": {"feature_a": True}}
expected = features.model_dump()
expected.pop("vector_space")
assert result == expected
get_features.assert_called_once_with("tenant_123", exclude_vector_space=True)
@ -35,7 +39,7 @@ class TestFeatureVectorSpaceApi:
from controllers.console.feature import FeatureVectorSpaceApi
get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space")
get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480}
get_vector_space.return_value = LimitationModel(size=5120, limit=20480)
api = FeatureVectorSpaceApi()
@ -85,22 +89,23 @@ class TestSystemFeatureApi:
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
fake_user.is_authenticated = True
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
account = make_account()
current_account = mocker.patch(
"controllers.console.feature.current_account_with_tenant_optional",
return_value=(account, "tenant-123"),
)
system_features = SystemFeatureModel(is_allow_register=True)
get_system_features = mocker.patch(
"controllers.console.feature.FeatureService.get_system_features",
return_value=system_features,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": True}}
assert result == system_features.model_dump()
current_account.assert_called_once_with()
get_system_features.assert_called_once_with(is_authenticated=True)
def test_get_system_features_unauthenticated(self, mocker: MockerFixture):
"""
@ -109,19 +114,19 @@ class TestSystemFeatureApi:
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
current_account = mocker.patch(
"controllers.console.feature.current_account_with_tenant_optional",
return_value=(None, None),
)
system_features = SystemFeatureModel(is_allow_register=False)
get_system_features = mocker.patch(
"controllers.console.feature.FeatureService.get_system_features",
return_value=system_features,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": False}}
assert result == system_features.model_dump()
current_account.assert_called_once_with()
get_system_features.assert_called_once_with(is_authenticated=False)

View File

@ -1,3 +1,4 @@
from typing import override
from unittest.mock import MagicMock, patch
import pytest
@ -36,6 +37,7 @@ class MockUser(UserMixin):
self.id = user_id
self.current_tenant_id = "tenant123"
@override
def get_id(self) -> str:
return self.id
@ -210,6 +212,7 @@ class TestModelValidationInjection:
Handler().post()
assert exc_info.value.code == 422
assert exc_info.value.description is not None
assert "count" in exc_info.value.description

View File

@ -9,21 +9,21 @@ Strategy:
- ``HitTestingApi.post`` is decorated with ``@cloud_edition_billing_rate_limit_check``
which preserves ``__wrapped__``. We call ``post.__wrapped__(self, ...)`` to skip
the billing decorator and test the business logic directly.
- Base-class methods (``get_and_validate_dataset``, ``perform_hit_testing``) read
``current_user`` from ``controllers.console.datasets.hit_testing_base``, so we
patch it there.
- ``validate_dataset_token`` installs the tenant owner account into Flask-Login's
request context before calling the handler, so direct method-call tests install
the same concrete account on ``g._login_user``.
"""
import uuid
from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest
from flask import Flask
from flask import Flask, g
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
from models.account import Account
from models.account import Account, Tenant, TenantAccountRole
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
@ -131,13 +131,21 @@ class TestHitTestingApiPost:
def _dataset(dataset_id: str, tenant_id: str) -> Dataset:
return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1")
@staticmethod
def _account(tenant_id: str) -> Account:
account = Account(name="Service API", email="service-api@example.com")
account.id = "account-1"
tenant = Tenant(name="Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
account.role = TenantAccountRole.OWNER
return account
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_success(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_ns,
@ -148,6 +156,7 @@ class TestHitTestingApiPost:
tenant_id = str(uuid.uuid4())
mock_dataset = self._dataset(dataset_id, tenant_id)
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
@ -158,6 +167,8 @@ class TestHitTestingApiPost:
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
# Skip billing decorator via __wrapped__
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@ -168,10 +179,8 @@ class TestHitTestingApiPost:
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_with_retrieval_model(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_ns,
@ -182,6 +191,7 @@ class TestHitTestingApiPost:
tenant_id = str(uuid.uuid4())
mock_dataset = self._dataset(dataset_id, tenant_id)
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
@ -204,6 +214,8 @@ class TestHitTestingApiPost:
}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@ -218,10 +230,8 @@ class TestHitTestingApiPost:
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_preserves_retrieval_model_metadata_filtering_conditions(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_ns,
@ -232,6 +242,7 @@ class TestHitTestingApiPost:
tenant_id = str(uuid.uuid4())
mock_dataset = self._dataset(dataset_id, tenant_id)
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
@ -260,6 +271,8 @@ class TestHitTestingApiPost:
}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@ -270,10 +283,8 @@ class TestHitTestingApiPost:
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_prepares_nullable_list_fields(
self,
mock_current_user,
mock_dataset_svc,
mock_hit_svc,
mock_ns,
@ -284,6 +295,7 @@ class TestHitTestingApiPost:
tenant_id = str(uuid.uuid4())
mock_dataset = self._dataset(dataset_id, tenant_id)
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.return_value = None
@ -297,6 +309,8 @@ class TestHitTestingApiPost:
mock_ns.payload = {"query": "legacy query"}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@ -312,10 +326,8 @@ class TestHitTestingApiPost:
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_dataset_not_found(
self,
mock_current_user,
mock_dataset_svc,
mock_ns,
app: Flask,
@ -323,21 +335,22 @@ class TestHitTestingApiPost:
"""Test hit testing with non-existent dataset."""
dataset_id = str(uuid.uuid4())
tenant_id = str(uuid.uuid4())
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = None
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
with pytest.raises(NotFound):
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_no_dataset_permission(
self,
mock_current_user,
mock_dataset_svc,
mock_ns,
app: Flask,
@ -347,6 +360,7 @@ class TestHitTestingApiPost:
tenant_id = str(uuid.uuid4())
mock_dataset = self._dataset(dataset_id, tenant_id)
account = self._account(tenant_id)
mock_dataset_svc.get_dataset.return_value = mock_dataset
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(
@ -355,6 +369,8 @@ class TestHitTestingApiPost:
mock_ns.payload = {"query": "test query"}
with app.test_request_context():
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
g._login_user = account
api = HitTestingApi()
with pytest.raises(Forbidden):
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)

View File

@ -1,15 +1,15 @@
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock
import pytest
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
from werkzeug.exceptions import Unauthorized
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
from models.account import Account, Tenant
@pytest.fixture
@ -23,7 +23,7 @@ def protected_view():
return _protected_view
class MockUser(UserMixin):
class MockUser:
"""Mock user class for testing."""
def __init__(self, id: str, is_authenticated: bool = True):
@ -35,6 +35,22 @@ class MockUser(UserMixin):
return self._is_authenticated
class LoginManagerStub:
def __init__(self, unauthorized_response: Response) -> None:
self._unauthorized_response = unauthorized_response
def unauthorized(self) -> Response:
return self._unauthorized_response
def _login_manager(app: Flask) -> DifyLoginManager:
return cast(DifyLoginManager, app.__dict__["login_manager"])
def _unauthorized_mock(app: Flask) -> MagicMock:
return cast(MagicMock, _login_manager(app).unauthorized)
@pytest.fixture
def login_app(mocker: MockerFixture) -> Flask:
app = Flask(__name__)
@ -95,7 +111,7 @@ class TestLoginRequired:
assert result == "Protected content"
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_not_called()
_unauthorized_mock(login_app).assert_not_called()
@pytest.mark.parametrize(
("resolved_user", "description"),
@ -120,11 +136,11 @@ class TestLoginRequired:
with login_app.test_request_context():
result = protected_view()
assert result is login_app.login_manager.unauthorized.return_value, description
assert result is _unauthorized_mock(login_app).return_value, description
assert isinstance(result, Response)
assert result.status_code == 401
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
_unauthorized_mock(login_app).assert_called_once_with()
csrf_check.assert_not_called()
def test_unauthorized_access_propagates_response_object(
@ -138,9 +154,7 @@ class TestLoginRequired:
"""Test that unauthorized responses are propagated as Flask Response objects."""
resolve_user = resolve_current_user(None)
response = Response("Unauthorized", status=401, content_type="application/json")
mocker.patch.object(
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
)
mocker.patch.object(login_module, "_get_login_manager", return_value=LoginManagerStub(response))
with login_app.test_request_context():
result = protected_view()
@ -177,7 +191,7 @@ class TestLoginRequired:
assert result == "Protected content"
resolve_user.assert_not_called()
csrf_check.assert_not_called()
login_app.login_manager.unauthorized.assert_not_called()
_unauthorized_mock(login_app).assert_not_called()
class TestGetUser:
@ -191,6 +205,7 @@ class TestGetUser:
g._login_user = mock_user
user = login_module._get_user()
assert user == mock_user
assert user is not None
assert user.id == "test_user"
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask, mocker: MockerFixture):
@ -201,7 +216,7 @@ class TestGetUser:
g._login_user = mock_user
load_user = mocker.patch.object(
login_app.login_manager,
_login_manager(login_app),
"load_user_from_request_context",
side_effect=load_user_from_request_context,
)
@ -244,7 +259,9 @@ class TestCurrentAccountWithTenant:
def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
account = Account(name="Test User", email="test@example.com")
account._current_tenant = SimpleNamespace(id="tenant-123")
tenant = Tenant(name="Test Tenant")
tenant.id = "tenant-123"
account._current_tenant = tenant
current_user_proxy = mocker.Mock()
current_user_proxy._get_current_object.return_value = account
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
@ -267,3 +284,58 @@ class TestCurrentAccountWithTenant:
with pytest.raises(AssertionError, match="tenant information should be loaded"):
login_module.current_account_with_tenant()
class TestCurrentAccountWithTenantOptional:
"""Test cases for optional current account resolution."""
def test_returns_account_and_tenant_id_for_authenticated_account(self, mocker: MockerFixture) -> None:
account = Account(name="Test User", email="test@example.com")
tenant = Tenant(name="Test Tenant")
tenant.id = "tenant-123"
account._current_tenant = tenant
mocker.patch.object(login_module, "_resolve_current_user", return_value=account)
user, tenant_id = login_module.current_account_with_tenant_optional()
assert user is account
assert tenant_id == "tenant-123"
def test_returns_none_pair_when_request_loader_raises_unauthorized(self, mocker: MockerFixture) -> None:
mocker.patch.object(login_module, "_resolve_current_user", side_effect=Unauthorized())
user, tenant_id = login_module.current_account_with_tenant_optional()
assert user is None
assert tenant_id is None
def test_returns_none_pair_when_resolved_user_is_not_account(self, mocker: MockerFixture) -> None:
mocker.patch.object(login_module, "_resolve_current_user", return_value=MockUser("end-user"))
user, tenant_id = login_module.current_account_with_tenant_optional()
assert user is None
assert tenant_id is None
class TestResolveTenantIdFallback:
"""Test cases for tenant-only fallback helper."""
def test_returns_provided_tenant_id_without_current_user_lookup(self, mocker: MockerFixture) -> None:
current_account_with_tenant = mocker.patch.object(login_module, "current_account_with_tenant")
tenant_id = login_module.resolve_tenant_id_fallback("tenant-123")
assert tenant_id == "tenant-123"
current_account_with_tenant.assert_not_called()
def test_falls_back_to_current_account_tenant(self, mocker: MockerFixture) -> None:
account = Account(name="Test User", email="test@example.com")
tenant = Tenant(name="Test Tenant")
tenant.id = "tenant-123"
account._current_tenant = tenant
mocker.patch.object(login_module, "current_account_with_tenant", return_value=(account, tenant.id))
tenant_id = login_module.resolve_tenant_id_fallback()
assert tenant_id == "tenant-123"

View File

@ -5,10 +5,6 @@ from services.rag_pipeline.pipeline_template.pipeline_template_type import Pipel
def test_get_pipeline_templates(mocker) -> None:
mocker.patch(
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.current_account_with_tenant",
return_value=("account-id", "tenant-id"),
)
customized_template = SimpleNamespace(
id="tpl-1",
name="Custom Template",
@ -27,7 +23,7 @@ def test_get_pipeline_templates(mocker) -> None:
)
retrieval = CustomizedPipelineTemplateRetrieval()
result = retrieval.get_pipeline_templates("en-US")
result = retrieval.get_pipeline_templates("en-US", "tenant-id")
assert retrieval.get_type() == PipelineTemplateType.CUSTOMIZED
assert result == {

View File

@ -1,9 +1,14 @@
import json
import time
from datetime import datetime
from types import SimpleNamespace
import pytest
from sqlalchemy.orm import sessionmaker
from models import Account, Tenant
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin
from models.workflow import Workflow
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
@ -25,6 +30,89 @@ class MockRepo:
pass
def _make_account(account_id: str = "u1", tenant_id: str = "t1") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
tenant = Tenant(name="Test Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
return account
def _make_pipeline(
*,
pipeline_id: str = "p1",
tenant_id: str = "t1",
workflow_id: str | None = None,
is_published: bool = False,
) -> Pipeline:
pipeline = Pipeline(tenant_id=tenant_id, name="Test Pipeline", description="test")
pipeline.id = pipeline_id
pipeline.workflow_id = workflow_id
pipeline.is_published = is_published
return pipeline
def _make_workflow(
*,
workflow_id: str = "wf-1",
tenant_id: str = "t1",
app_id: str = "p1",
graph: dict[str, object] | None = None,
features: dict[str, object] | None = None,
created_by: str = "u1",
) -> Workflow:
workflow = Workflow(
id=workflow_id,
tenant_id=tenant_id,
app_id=app_id,
type="workflow",
version="draft",
marked_name="",
marked_comment="",
graph=json.dumps(graph or {"nodes": []}),
features=json.dumps(features or {}),
created_by=created_by,
created_at=datetime(2024, 1, 1),
updated_by=None,
updated_at=datetime(2024, 1, 1),
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
return workflow
def _make_dataset(*, dataset_id: str = "d1", pipeline_id: str = "p1", tenant_id: str = "t1") -> Dataset:
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
name="Test Dataset",
created_by="u1",
)
dataset.pipeline_id = pipeline_id
return dataset
def _make_customized_template() -> PipelineCustomizedTemplate:
return PipelineCustomizedTemplate(
tenant_id="t1",
name="old",
description="old",
chunk_structure="paragraph",
icon={},
position=1,
yaml_content="",
install_count=0,
language="en-US",
created_by="u1",
)
def _make_recommended_plugin(plugin_id: str) -> PipelineRecommendedPlugin:
return PipelineRecommendedPlugin(plugin_id=plugin_id, provider_name=plugin_id, type="tool", position=0, active=True)
def test_get_pipeline_templates_fallbacks_to_builtin_for_non_english_empty_result(mocker) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE", "remote")
@ -74,7 +162,7 @@ def test_get_pipeline_template_detail_uses_expected_mode(mocker, template_type:
def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(rag_pipeline_service) -> None:
pipeline = SimpleNamespace(workflow_id=None)
pipeline = _make_pipeline(workflow_id=None)
result = rag_pipeline_service.get_published_workflow(pipeline)
@ -82,7 +170,7 @@ def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(ra
def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_pipeline_service) -> None:
pipeline = SimpleNamespace(workflow_id=None)
pipeline = _make_pipeline(workflow_id=None)
session = SimpleNamespace()
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
@ -101,7 +189,7 @@ def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_p
def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_service) -> None:
scalars_result = SimpleNamespace(all=lambda: ["wf1", "wf2", "wf3"])
session = SimpleNamespace(scalars=lambda stmt: scalars_result)
pipeline = SimpleNamespace(id="pipeline-1", workflow_id="wf-live")
pipeline = _make_pipeline(pipeline_id="pipeline-1", workflow_id="wf-live")
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
session=session,
@ -133,8 +221,8 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.flush")
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
pipeline = SimpleNamespace(tenant_id="t1", id="p1", workflow_id=None)
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline(workflow_id=None)
account = _make_account()
result = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
@ -153,11 +241,11 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s
def test_sync_draft_workflow_raises_on_hash_mismatch(mocker, rag_pipeline_service) -> None:
from services.errors.app import WorkflowHashNotEqualError
existing_wf = SimpleNamespace(unique_hash="hash-old")
existing_wf = _make_workflow(graph={"nodes": [{"id": "old"}]})
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline()
account = _make_account()
with pytest.raises(WorkflowHashNotEqualError):
rag_pipeline_service.sync_draft_workflow(
@ -184,8 +272,8 @@ def test_sync_draft_workflow_updates_existing(mocker, rag_pipeline_service) -> N
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline()
account = _make_account()
result = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
@ -275,7 +363,7 @@ def test_get_rag_pipeline_paginate_workflow_runs_delegates(mocker, rag_pipeline_
repo_mock.get_paginated_workflow_runs.return_value = expected
rag_pipeline_service._workflow_run_repo = repo_mock
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
pipeline = _make_pipeline()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 10, "last_id": "abc"})
assert result is expected
@ -297,7 +385,7 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
repo_mock.get_workflow_run_by_id.return_value = expected
rag_pipeline_service._workflow_run_repo = repo_mock
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
pipeline = _make_pipeline()
result = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline, "run-1")
assert result is expected
@ -310,14 +398,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
pipeline = _make_pipeline()
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
pipeline = _make_pipeline()
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@ -635,17 +723,10 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Pipeline
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p1"
pipeline.tenant_id = "t1"
pipeline.workflow_id = "wf-1"
pipeline.is_published = True
pipeline = _make_pipeline(workflow_id="wf-1", is_published=True)
workflow = mocker.Mock()
workflow.id = "wf-1"
workflow = _make_workflow(workflow_id="wf-1")
# Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
@ -656,8 +737,8 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
mock_db.session.scalar.side_effect = [None, 5]
# Mock retrieve_dataset
dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset
dataset = _make_dataset()
dataset.chunk_structure = "paragraph"
# Mock RagPipelineDslService
mock_dsl_service = mocker.Mock()
@ -665,16 +746,14 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.RagPipelineDslService", return_value=mock_dsl_service)
# Mock Session and commit
mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=mocker.MagicMock())
session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker")
session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = dataset
# Mock current_user
mock_user = mocker.Mock()
mock_user.id = "user-123"
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", mock_user)
account = _make_account(account_id="user-123")
# 2. Run test
args = {"name": "New Template", "description": "Desc", "icon_info": {"icon": "star"}, "tags": ["tag1"]}
rag_pipeline_service.publish_customized_pipeline_template("p1", args)
rag_pipeline_service.publish_customized_pipeline_template("p1", args, account, "t1")
# 3. Assertions
# Verify a new template was added to session or similar?
@ -687,14 +766,10 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline
# 1. Setup mocks
dataset = mocker.Mock(spec=Dataset)
dataset.pipeline_id = "p1"
dataset = _make_dataset()
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p1"
pipeline = _make_pipeline()
workflow = mocker.Mock()
workflow.graph_dict = {
@ -835,15 +910,10 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Pipeline
from models.workflow import Workflow
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p1"
pipeline.tenant_id = "t1"
pipeline = _make_pipeline()
workflow = mocker.Mock(spec=Workflow)
workflow = _make_workflow()
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.scalar.return_value = workflow
@ -856,16 +926,10 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Pipeline
from models.workflow import Workflow
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p1"
pipeline.tenant_id = "t1"
pipeline.workflow_id = "wf-pub"
pipeline = _make_pipeline(workflow_id="wf-pub")
workflow = mocker.Mock(spec=Workflow)
workflow = _make_workflow(workflow_id="wf-pub")
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.scalar.return_value = workflow
@ -896,8 +960,8 @@ def test_get_default_block_config_success(rag_pipeline_service) -> None:
def test_publish_workflow_raises_when_draft_workflow_missing(mocker, rag_pipeline_service) -> None:
session = mocker.Mock()
session.scalar.return_value = None
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline()
account = _make_account()
with pytest.raises(ValueError, match="No valid workflow found"):
rag_pipeline_service.publish_workflow(session=session, pipeline=pipeline, account=account)
@ -929,8 +993,8 @@ def test_get_default_block_config_injects_http_request_filter(mocker, rag_pipeli
def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline()
account = _make_account()
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not initialized"):
@ -939,8 +1003,8 @@ def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeli
def test_run_draft_workflow_node_saves_execution_and_variables(mocker, rag_pipeline_service) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock()))
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
account = SimpleNamespace(id="u1")
pipeline = _make_pipeline()
account = _make_account()
draft_workflow = mocker.Mock(id="wf-1")
draft_workflow.get_node_config_by_id.return_value = {"id": "node-1"}
draft_workflow.get_enclosing_node_type_and_id.return_value = ("loop", "enclosing-node")
@ -1163,11 +1227,11 @@ def test_get_second_step_parameters_handles_string_and_list_variable_references(
def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
pipeline = _make_pipeline()
mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=None)
result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1")
pipeline=pipeline, run_id="run-1", user=_make_account()
)
assert result == []
@ -1175,14 +1239,14 @@ def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mo
def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions(mocker, rag_pipeline_service) -> None:
mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock()))
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
pipeline = _make_pipeline()
mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=SimpleNamespace(id="run-1"))
repo = mocker.Mock()
repo.get_db_models_by_workflow_run.return_value = ["n1", "n2"]
mocker.patch("services.rag_pipeline.rag_pipeline.SQLAlchemyWorkflowNodeExecutionRepository", return_value=repo)
result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1")
pipeline=pipeline, run_id="run-1", user=_make_account()
)
assert result == ["n1", "n2"]
@ -1192,7 +1256,7 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all")
result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1")
assert result == {
"installed_recommended_plugins": [],
@ -1201,11 +1265,10 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b")
plugin_a = _make_recommended_plugin("plugin-a")
plugin_b = _make_recommended_plugin("plugin-b")
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
return_value=[SimpleNamespace(plugin_id="plugin-a", to_dict=lambda: {"plugin_id": "plugin-a"})],
@ -1215,7 +1278,7 @@ def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_p
return_value=[{"plugin_id": "plugin-b", "name": "Plugin B"}],
)
result = rag_pipeline_service.get_recommended_plugins("custom")
result = rag_pipeline_service.get_recommended_plugins("custom", _make_account(), "t1")
assert result["installed_recommended_plugins"] == [{"plugin_id": "plugin-a"}]
assert result["uninstalled_recommended_plugins"] == [{"plugin_id": "plugin-b", "name": "Plugin B"}]
@ -1229,8 +1292,8 @@ def test_get_node_last_run_delegates_to_repository(mocker, rag_pipeline_service)
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
return_value=repo,
)
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
workflow = SimpleNamespace(id="wf1")
pipeline = _make_pipeline()
workflow = _make_workflow(workflow_id="wf1")
result = rag_pipeline_service.get_node_last_run(pipeline, workflow, "node-1")
@ -1572,15 +1635,17 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
pipeline = _make_pipeline(workflow_id=None)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
rag_pipeline_service.publish_customized_pipeline_template(
"p1", {"name": "template-name"}, _make_account(), "t1"
)
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
@ -1630,13 +1695,12 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
template = _make_customized_template()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="", description="updated", icon_info=IconInfo(icon="i"))
result = RagPipelineService.update_customized_pipeline_template("tpl-1", info)
result = RagPipelineService.update_customized_pipeline_template("tpl-1", info, _make_account(), "t1")
assert result.description == "updated"
commit.assert_called_once()
@ -1644,7 +1708,7 @@ def test_update_customized_pipeline_template_commits_when_name_empty(mocker) ->
def test_get_all_published_workflow_without_filters_has_no_more(rag_pipeline_service) -> None:
session = SimpleNamespace(scalars=lambda stmt: SimpleNamespace(all=lambda: ["wf1"]))
pipeline = SimpleNamespace(id="p1", workflow_id="wf-live")
pipeline = _make_pipeline(workflow_id="wf-live")
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
session=session,
@ -1856,38 +1920,34 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
pipeline = _make_pipeline(workflow_id="wf-1")
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1")
pipeline = _make_pipeline(workflow_id="wf-1")
workflow = _make_workflow(workflow_id="wf-1")
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False
mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=session_ctx)
pipeline.retrieve_dataset = lambda session: None
session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker")
session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = None
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a")
plugin = _make_recommended_plugin("plugin-a")
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
result = rag_pipeline_service.get_recommended_plugins("all")
result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1")
assert result["installed_recommended_plugins"] == []
assert result["uninstalled_recommended_plugins"] == []
@ -1918,8 +1978,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
dataset = _make_dataset()
pipeline = _make_pipeline()
workflow = SimpleNamespace(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
@ -2079,8 +2139,8 @@ def test_set_datasource_variables_raises_when_workflow_missing(mocker, rag_pipel
def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
dataset = _make_dataset()
pipeline = _make_pipeline()
workflow = SimpleNamespace(
graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
@ -2097,8 +2157,8 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
dataset = _make_dataset()
pipeline = _make_pipeline()
workflow = SimpleNamespace(
graph_dict={
"nodes": [
@ -2139,8 +2199,8 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
dataset = _make_dataset()
pipeline = _make_pipeline()
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

View File

@ -1,65 +1,62 @@
from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
from typing import cast
from unittest.mock import Mock
import pytest
from models.account import Account
from models import Account, Tenant
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
tenant = Tenant(name="Test Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
return account
class TestMetadataBugCompleteValidation:
"""Complete test suite to verify the metadata nullable bug and its fix."""
def test_1_pydantic_layer_validation(self):
def test_1_pydantic_layer_validation(self) -> None:
"""Test Layer 1: Pydantic model validation correctly rejects None values."""
# Pydantic should reject None values for required fields
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type=None, name=None)
MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type]
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type="string", name=None)
MetadataArgs(type="string", name=None) # pyrefly: ignore[bad-argument-type]
with pytest.raises((ValueError, TypeError)):
MetadataArgs(type=None, name="test")
MetadataArgs(type=None, name="test") # pyrefly: ignore[bad-argument-type]
# Valid values should work
valid_args = MetadataArgs(type="string", name="test_name")
assert valid_args.type == "string"
assert valid_args.name == "test_name"
def test_2_business_logic_layer_crashes_on_none(self):
def test_2_business_logic_layer_crashes_on_none(self) -> None:
"""Test Layer 2: Business logic crashes when None values slip through."""
# Create mock that bypasses Pydantic validation
mock_metadata_args = Mock()
mock_metadata_args.name = None
mock_metadata_args.type = "string"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
account = _make_account()
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
# Test update method as well
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
account = _make_account()
none_name = cast(str, None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123")
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
def test_3_database_constraints_verification(self):
def test_3_database_constraints_verification(self) -> None:
"""Test Layer 3: Verify database model has nullable=False constraints."""
from sqlalchemy import inspect
@ -75,7 +72,7 @@ class TestMetadataBugCompleteValidation:
assert type_column.nullable is False, "type column should be nullable=False"
assert name_column.nullable is False, "name column should be nullable=False"
def test_4_fixed_api_layer_rejects_null(self):
def test_4_fixed_api_layer_rejects_null(self) -> None:
"""Test Layer 4: Fixed API configuration properly rejects null values using Pydantic."""
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate({"type": None, "name": None})
@ -86,30 +83,23 @@ class TestMetadataBugCompleteValidation:
with pytest.raises((ValueError, TypeError)):
MetadataArgs.model_validate({"type": None, "name": "test"})
def test_5_fixed_api_accepts_valid_values(self):
def test_5_fixed_api_accepts_valid_values(self) -> None:
"""Test that fixed API still accepts valid non-null values."""
args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"})
assert args.type == "string"
assert args.name == "valid_name"
def test_6_simulated_buggy_behavior(self):
def test_6_simulated_buggy_behavior(self) -> None:
"""Test simulating the original buggy behavior by bypassing Pydantic validation."""
mock_metadata_args = Mock()
mock_metadata_args.name = None
mock_metadata_args.type = None
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
account = _make_account()
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_7_end_to_end_validation_layers(self):
def test_7_end_to_end_validation_layers(self) -> None:
"""Test all validation layers work together correctly."""
# Layer 1: API should reject null at parameter level (with fix)
# Layer 2: Pydantic should reject null at model level
@ -128,7 +118,7 @@ class TestMetadataBugCompleteValidation:
assert len(metadata_args.name) <= 255 # This should not crash
assert len(metadata_args.type) > 0 # This should not crash
def test_8_verify_specific_fix_locations(self):
def test_8_verify_specific_fix_locations(self) -> None:
"""Verify that the specific locations mentioned in bug report are fixed."""
# Read the actual files to verify fixes
import os
@ -152,7 +142,7 @@ class TestMetadataBugCompleteValidation:
class TestMetadataValidationSummary:
"""Summary tests that demonstrate the complete validation architecture."""
def test_validation_layer_architecture(self):
def test_validation_layer_architecture(self) -> None:
"""Document and test the 4-layer validation architecture."""
# Layer 1: API Parameter Validation (Flask-RESTful reqparse)
# - Role: First line of defense, validates HTTP request parameters

View File

@ -1,56 +1,53 @@
from unittest.mock import Mock, create_autospec, patch
from typing import cast
from unittest.mock import Mock
import pytest
from models.account import Account
from models import Account, Tenant
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService
def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
tenant = Tenant(name="Test Tenant")
tenant.id = tenant_id
account._current_tenant = tenant
return account
class TestMetadataNullableBug:
"""Test case to reproduce the metadata nullable validation bug."""
def test_metadata_args_with_none_values_should_fail(self):
def test_metadata_args_with_none_values_should_fail(self) -> None:
"""Test that MetadataArgs validation should reject None values."""
# This test demonstrates the expected behavior - should fail validation
with pytest.raises((ValueError, TypeError)):
# This should fail because Pydantic expects non-None values
MetadataArgs(type=None, name=None)
MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type]
def test_metadata_service_create_with_none_name_crashes(self):
def test_metadata_service_create_with_none_name_crashes(self) -> None:
"""Test that MetadataService.create_metadata crashes when name is None."""
# Mock the MetadataArgs to bypass Pydantic validation
mock_metadata_args = Mock()
mock_metadata_args.name = None # This will cause len() to crash
mock_metadata_args.type = "string"
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
account = _make_account()
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_metadata_service_update_with_none_name_crashes(self):
def test_metadata_service_update_with_none_name_crashes(self) -> None:
"""Test that MetadataService.update_metadata_name crashes when name is None."""
mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
account = _make_account()
none_name = cast(str, None)
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123")
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
def test_api_layer_now_uses_pydantic_validation(self):
def test_api_layer_now_uses_pydantic_validation(self) -> None:
"""Verify that API layer relies on Pydantic validation instead of reqparse."""
invalid_payload = {"type": None, "name": None}
with pytest.raises((ValueError, TypeError)):