mirror of
https://github.com/langgenius/dify.git
synced 2026-06-11 10:57:40 +08:00
refactor(api): migrate remaining console APIs to use injected user/tenant (#37288)
This commit is contained in:
parent
5ed663e7fd
commit
2a46a7d91d
@ -8,6 +8,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
|
||||
from .. import console_ns
|
||||
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
@ -15,6 +16,8 @@ from ..wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
register_schema_models(console_ns, HitTestingPayload)
|
||||
@ -38,11 +41,16 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID) -> dict[str, object]:
|
||||
@with_current_tenant_id
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, current_tenant_id: str, dataset_id: UUID) -> dict[str, object]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str, current_user, current_tenant_id)
|
||||
args = self.parse_args(console_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
return dump_response(
|
||||
HitTestingResponse,
|
||||
self.perform_hit_testing(dataset, args, current_user, current_tenant_id),
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_user
|
||||
from libs.login import resolve_account_fallback
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
@ -71,8 +71,10 @@ class DatasetsHitTestingBase:
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str) -> Dataset:
|
||||
assert isinstance(current_user, Account)
|
||||
def get_and_validate_dataset(
|
||||
dataset_id: str, current_user: Account | None = None, current_tenant_id: str | None = None
|
||||
) -> Dataset:
|
||||
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -95,9 +97,14 @@ class DatasetsHitTestingBase:
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
|
||||
assert isinstance(current_user, Account)
|
||||
def perform_hit_testing(
|
||||
dataset: Dataset,
|
||||
args: dict[str, Any],
|
||||
current_user: Account | None = None,
|
||||
current_tenant_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=cast(str, args.get("query")),
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
@ -50,7 +51,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -59,7 +61,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@setup_required
|
||||
@ -87,7 +89,8 @@ class DatasetMetadataApi(Resource):
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
@ -98,7 +101,9 @@ class DatasetMetadataApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
metadata = MetadataService.update_metadata_name(
|
||||
dataset_id_str, metadata_id_str, name, current_user, current_tenant_id
|
||||
)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@setup_required
|
||||
@ -181,7 +186,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
|
||||
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_user)
|
||||
|
||||
# Frontend callers only await success and invalidate caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -21,11 +21,14 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
knowledge_pipeline_publish_enabled,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -96,10 +99,11 @@ class PipelineTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self) -> JsonResponseWithStatus:
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language, current_tenant_id)
|
||||
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
|
||||
|
||||
|
||||
@ -128,10 +132,14 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str) -> tuple[str, int]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, template_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
RagPipelineService.update_customized_pipeline_template(
|
||||
template_id, pipeline_template_info, current_user, current_tenant_id
|
||||
)
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(204, "Pipeline template deleted")
|
||||
@ -139,8 +147,9 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str) -> tuple[str, int]:
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, template_id: str) -> tuple[str, int]:
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id, current_tenant_id)
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@ -168,8 +177,12 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str) -> tuple[str, int]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, pipeline_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
rag_pipeline_service.publish_customized_pipeline_template(
|
||||
pipeline_id, payload.model_dump(), current_user, current_tenant_id
|
||||
)
|
||||
return "", 204
|
||||
|
||||
@ -48,7 +48,7 @@ from fields.workflow_run_fields import (
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import EndUser
|
||||
@ -835,7 +835,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": str(query.last_id) if query.last_id else None,
|
||||
"last_id": query.last_id or None,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
@ -881,7 +881,8 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, pipeline: Pipeline, run_id: UUID):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
@ -988,9 +989,11 @@ class RagPipelineRecommendedPluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
|
||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type, current_user, current_tenant_id)
|
||||
return recommended_plugins
|
||||
|
||||
@ -18,7 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.console.wraps import with_current_user, with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -30,7 +30,6 @@ from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -84,7 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -101,8 +101,6 @@ class CompletionApi(InstalledAppResource):
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
@ -160,7 +158,8 @@ class CompletionStopApi(InstalledAppResource):
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -177,8 +176,6 @@ class ChatApi(InstalledAppResource):
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -19,7 +20,6 @@ from fields.conversation_fields import (
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.conversation_service import ConversationService
|
||||
@ -45,7 +45,8 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app: InstalledApp):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -66,14 +67,12 @@ class ConversationListApi(InstalledAppResource):
|
||||
args = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
last_id=str(args.last_id) if args.last_id else None,
|
||||
last_id=args.last_id or None,
|
||||
limit=args.limit,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
pinned=args.pinned,
|
||||
@ -95,7 +94,8 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app: InstalledApp, c_id: UUID):
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -105,8 +105,6 @@ class ConversationApi(InstalledAppResource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -120,7 +118,8 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp, c_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -133,8 +132,6 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
conversation = ConversationService.rename(
|
||||
app_model, conversation_id, current_user, payload.name, payload.auto_generate
|
||||
)
|
||||
@ -153,7 +150,8 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -164,8 +162,6 @@ class ConversationPinApi(InstalledAppResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.pin(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -179,7 +175,8 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -188,8 +185,6 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json")
|
||||
|
||||
@ -33,6 +33,7 @@ from controllers.console.explore.error import (
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.console.wraps import with_current_user
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -63,7 +64,6 @@ from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
@ -155,7 +155,8 @@ register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeech
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
@ -168,7 +169,6 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
@ -206,7 +206,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
@ -221,7 +220,8 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -239,9 +239,6 @@ class TrialChatApi(TrialAppResource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
@ -276,7 +273,8 @@ class TrialChatApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
def get(self, trial_app, message_id):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -285,8 +283,6 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
@ -313,15 +309,13 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
@ -358,7 +352,8 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
@ -366,8 +361,6 @@ class TrialChatTextApi(TrialAppResource):
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
@ -405,7 +398,8 @@ class TrialChatTextApi(TrialAppResource):
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -417,9 +411,6 @@ class TrialCompletionApi(TrialAppResource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import current_account_with_tenant_optional, login_required
|
||||
from services.feature_service import (
|
||||
FeatureModel,
|
||||
FeatureService,
|
||||
@ -13,7 +12,12 @@ from services.feature_service import (
|
||||
)
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
cloud_utm_record,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
|
||||
|
||||
class TrialModelsResponse(ResponseModel):
|
||||
@ -133,12 +137,6 @@ class SystemFeatureApi(Resource):
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
current_user, _ = current_account_with_tenant_optional()
|
||||
is_authenticated = current_user is not None
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
|
||||
@ -17,7 +17,7 @@ from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -31,6 +31,8 @@ from ..wraps import (
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -77,12 +79,9 @@ class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
return TriggerManager.get_trigger_plugin_icon(tenant_id=tenant_id, provider_id=provider)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/triggers")
|
||||
@ -90,12 +89,10 @@ class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
@ -103,14 +100,10 @@ class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||
)
|
||||
return jsonable_encoder(TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider)))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
@ -119,16 +112,14 @@ class TriggerSubscriptionListApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
user=user,
|
||||
)
|
||||
@ -149,17 +140,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
credential_type = CredentialType.of(payload.credential_type)
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
@ -194,17 +184,16 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
# Use atomic update_and_verify to prevent race conditions
|
||||
return TriggerSubscriptionBuilderService.update_and_verify_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
@ -226,17 +215,14 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str, subscription_builder_id: str):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||
@ -262,10 +248,6 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||
@ -283,15 +265,15 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
try:
|
||||
# Use atomic update_and_build to prevent race conditions
|
||||
TriggerSubscriptionBuilderService.update_and_build_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
@ -316,15 +298,13 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, subscription_id: str):
|
||||
"""Update a subscription instance"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
subscription = TriggerProviderService.get_subscription_by_id(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
if not subscription:
|
||||
@ -341,7 +321,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
|
||||
if rename or manually_created:
|
||||
TriggerProviderService.update_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
name=request.name,
|
||||
properties=request.properties,
|
||||
@ -351,7 +331,7 @@ class TriggerSubscriptionUpdateApi(Resource):
|
||||
# For the rest cases(API_KEY, OAUTH2)
|
||||
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
|
||||
TriggerProviderService.rebuild_trigger_subscription(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
name=request.name,
|
||||
provider_id=provider_id,
|
||||
subscription_id=subscription_id,
|
||||
@ -375,23 +355,21 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, subscription_id: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, subscription_id: str):
|
||||
"""Delete a subscription instance"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
# Delete plugin triggers
|
||||
TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription(
|
||||
session=session,
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
return {"result": "success"}
|
||||
@ -407,17 +385,14 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
@ -557,30 +532,28 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
system_client_exists = TriggerProviderService.is_oauth_system_client_exists(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
return jsonable_encoder(
|
||||
{
|
||||
@ -603,17 +576,15 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=payload.client_params,
|
||||
enabled=payload.enabled,
|
||||
@ -629,16 +600,14 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
@ -657,16 +626,15 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, subscription_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str, subscription_id: str):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
result = TriggerProviderService.verify_subscription_credentials(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_id=subscription_id,
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Concatenate, cast, overload
|
||||
|
||||
from flask import Response, current_app, g, has_request_context, request
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from configs import dify_config
|
||||
@ -48,6 +49,53 @@ def current_account_with_tenant() -> tuple[Account, str]:
|
||||
return user, user.current_tenant_id
|
||||
|
||||
|
||||
def current_account_with_tenant_optional() -> tuple[Account | None, str | None]:
|
||||
try:
|
||||
user = _resolve_current_user()
|
||||
except Unauthorized:
|
||||
return None, None
|
||||
|
||||
if not isinstance(user, Account):
|
||||
return None, None
|
||||
if not bool(getattr(user, "is_authenticated", False)):
|
||||
return None, None
|
||||
return user, user.current_tenant_id
|
||||
|
||||
|
||||
def resolve_account_fallback(
|
||||
current_user: Account | None = None,
|
||||
current_tenant_id: str | None = None,
|
||||
*,
|
||||
fallback_tenant_id: str | None = None,
|
||||
) -> tuple[Account, str]:
|
||||
"""
|
||||
If the provided current user and tenant ID is None, fallback to current_account_with_tenant.
|
||||
This is useful for those service layers whose controllers are not migrated to use DI for
|
||||
resolving current user yet.
|
||||
|
||||
TODO: this should be removed after all ctrls (especially service API) are migrated
|
||||
"""
|
||||
if current_user is not None:
|
||||
tenant_id = current_tenant_id or fallback_tenant_id
|
||||
if tenant_id is None:
|
||||
raise ValueError("current_tenant_id is required when current_user is provided.")
|
||||
return current_user, tenant_id
|
||||
return current_account_with_tenant()
|
||||
|
||||
|
||||
def resolve_tenant_id_fallback(current_tenant_id: str | None = None) -> str:
|
||||
"""
|
||||
If the provided tenant ID is None, fallback to the tenant resolved from current_account_with_tenant.
|
||||
This is useful for tenant-only service paths whose controllers are not all migrated to tenant injection yet.
|
||||
|
||||
TODO: this should be removed after all ctrls (especially service API) are migrated
|
||||
"""
|
||||
if current_tenant_id is not None:
|
||||
return current_tenant_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return tenant_id
|
||||
|
||||
|
||||
@overload
|
||||
def login_required[T, **P, R](
|
||||
func: Callable[Concatenate[T, P], R],
|
||||
|
||||
@ -7,7 +7,8 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.login import resolve_account_fallback
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
|
||||
from models.enums import DatasetMetadataType
|
||||
from services.dataset_service import DocumentService
|
||||
@ -21,11 +22,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class MetadataService:
|
||||
@staticmethod
|
||||
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
|
||||
def create_metadata(
|
||||
dataset_id: str,
|
||||
metadata_args: MetadataArgs,
|
||||
current_user: Account | None = None, # TODO: the service_api is not migrated yet
|
||||
current_tenant_id: str | None = None,
|
||||
) -> DatasetMetadata:
|
||||
# check if metadata name is too long
|
||||
if len(metadata_args.name) > 255:
|
||||
raise ValueError("Metadata name cannot exceed 255 characters.")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
|
||||
# check if metadata name already exists
|
||||
if db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
@ -52,14 +58,20 @@ class MetadataService:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
|
||||
def update_metadata_name(
|
||||
dataset_id: str,
|
||||
metadata_id: str,
|
||||
name: str,
|
||||
current_user: Account | None = None,
|
||||
current_tenant_id: str | None = None, # TODO: the service_api is not migrated yet
|
||||
) -> DatasetMetadata | None:
|
||||
# check if metadata name is too long
|
||||
if len(name) > 255:
|
||||
raise ValueError("Metadata name cannot exceed 255 characters.")
|
||||
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
# check if metadata name already exists
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
|
||||
if db.session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(
|
||||
@ -107,6 +119,7 @@ class MetadataService:
|
||||
return metadata
|
||||
except Exception:
|
||||
logger.exception("Update metadata name failed")
|
||||
return None
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@ -217,7 +230,15 @@ class MetadataService:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
|
||||
def update_documents_metadata(
|
||||
dataset: Dataset,
|
||||
metadata_args: MetadataOperationData,
|
||||
current_user: Account | None = None, # TODO: the service_api is not migrated yet
|
||||
current_tenant_id: str | None = None,
|
||||
):
|
||||
current_user, current_tenant_id = resolve_account_fallback(
|
||||
current_user, current_tenant_id, fallback_tenant_id=dataset.tenant_id
|
||||
)
|
||||
for operation in metadata_args.operation_data:
|
||||
lock_key = f"document_metadata_lock_{operation.document_id}"
|
||||
try:
|
||||
@ -248,7 +269,6 @@ class MetadataService:
|
||||
)
|
||||
)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
for metadata_value in operation.metadata_list:
|
||||
# check if binding already exists
|
||||
if operation.partial_update:
|
||||
|
||||
@ -21,7 +21,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
return PipelineTemplateType.BUILTIN
|
||||
|
||||
@override
|
||||
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
|
||||
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
|
||||
del current_tenant_id
|
||||
result = self.fetch_pipeline_templates_from_builtin(language)
|
||||
return result
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import yaml
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.login import resolve_tenant_id_fallback
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
|
||||
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
|
||||
@ -40,8 +40,8 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
|
||||
current_tenant_id = resolve_tenant_id_fallback(current_tenant_id)
|
||||
return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
|
||||
|
||||
@override
|
||||
|
||||
@ -40,7 +40,8 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
|
||||
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
|
||||
del current_tenant_id
|
||||
return self.fetch_pipeline_templates_from_db(language)
|
||||
|
||||
@override
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Protocol
|
||||
class PipelineTemplateRetrievalBase(Protocol):
|
||||
"""Interface for pipeline template retrieval."""
|
||||
|
||||
def get_pipeline_templates(self, language: str) -> dict[str, Any]: ...
|
||||
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: ...
|
||||
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ...
|
||||
|
||||
|
||||
@ -25,7 +25,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
|
||||
|
||||
@override
|
||||
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
|
||||
def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]:
|
||||
del current_tenant_id
|
||||
try:
|
||||
return self.fetch_pipeline_templates_from_dify_official(language)
|
||||
except Exception as e:
|
||||
|
||||
@ -8,7 +8,6 @@ from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -54,6 +53,7 @@ from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_htt
|
||||
from graphon.runtime import VariablePool
|
||||
from graphon.variables.variables import Variable, VariableBase
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import resolve_account_fallback, resolve_tenant_id_fallback
|
||||
from models import Account
|
||||
from models.dataset import ( # type: ignore
|
||||
Dataset,
|
||||
@ -104,11 +104,16 @@ class RagPipelineService:
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
|
||||
def get_pipeline_templates(
|
||||
cls,
|
||||
type: str = "built-in",
|
||||
language: str = "en-US",
|
||||
current_tenant_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
result = retrieval_instance.get_pipeline_templates(language, current_tenant_id)
|
||||
if not result.get("pipeline_templates") and language != "en-US":
|
||||
template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval()
|
||||
result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US")
|
||||
@ -116,7 +121,7 @@ class RagPipelineService:
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
result = retrieval_instance.get_pipeline_templates(language)
|
||||
result = retrieval_instance.get_pipeline_templates(language, current_tenant_id)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@ -146,17 +151,24 @@ class RagPipelineService:
|
||||
return customized_result
|
||||
|
||||
@classmethod
|
||||
def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity):
|
||||
def update_customized_pipeline_template(
|
||||
cls,
|
||||
template_id: str,
|
||||
template_info: PipelineTemplateInfoEntity,
|
||||
current_user: Account | None = None,
|
||||
current_tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Update pipeline template.
|
||||
:param template_id: template id
|
||||
:param template_info: template info
|
||||
"""
|
||||
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -169,7 +181,7 @@ class RagPipelineService:
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
|
||||
PipelineCustomizedTemplate.id != template_id,
|
||||
)
|
||||
.limit(1)
|
||||
@ -184,15 +196,16 @@ class RagPipelineService:
|
||||
return customized_template
|
||||
|
||||
@classmethod
|
||||
def delete_customized_pipeline_template(cls, template_id: str):
|
||||
def delete_customized_pipeline_template(cls, template_id: str, current_tenant_id: str | None = None):
|
||||
"""
|
||||
Delete customized pipeline template.
|
||||
"""
|
||||
current_tenant_id = resolve_tenant_id_fallback(current_tenant_id)
|
||||
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -1174,10 +1187,17 @@ class RagPipelineService:
|
||||
return list(node_executions)
|
||||
|
||||
@classmethod
|
||||
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
|
||||
def publish_customized_pipeline_template(
|
||||
cls,
|
||||
pipeline_id: str,
|
||||
args: dict[str, Any],
|
||||
current_user: Account | None = None,
|
||||
current_tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
|
||||
pipeline = db.session.get(Pipeline, pipeline_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
@ -1357,7 +1377,7 @@ class RagPipelineService:
|
||||
return []
|
||||
return marketplace.batch_fetch_plugin_by_ids(plugin_ids)
|
||||
|
||||
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
|
||||
def get_recommended_plugins(self, type: str, current_user: Account, current_tenant_id: str) -> dict[str, Any]:
|
||||
# Query active recommended plugins
|
||||
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
|
||||
if type and type != "all":
|
||||
@ -1375,7 +1395,7 @@ class RagPipelineService:
|
||||
plugin_ids = [plugin.plugin_id for plugin in pipeline_recommended_plugins]
|
||||
providers = BuiltinToolManageService.list_builtin_tools(
|
||||
user_id=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
providers_map = {provider.plugin_id: provider.to_dict() for provider in providers}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from inspect import unwrap
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
@ -20,8 +21,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.account import Account
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import unwrap
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@ -53,7 +54,7 @@ class TestPipelineTemplateListApi:
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, str(uuid4()))
|
||||
|
||||
assert status == 200
|
||||
assert response == {
|
||||
@ -147,6 +148,9 @@ class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
@ -161,15 +165,18 @@ class TestCustomizedPipelineTemplateApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
response, status = method(api, tenant_id, account, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert update_mock.call_args.args[2] is account
|
||||
assert update_mock.call_args.args[3] == tenant_id
|
||||
assert status == 204
|
||||
assert response == ""
|
||||
|
||||
def test_delete_success(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -177,9 +184,9 @@ class TestCustomizedPipelineTemplateApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
response, status = method(api, tenant_id, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
delete_mock.assert_called_once_with("tpl-1", tenant_id)
|
||||
assert status == 204
|
||||
assert response == ""
|
||||
|
||||
@ -227,6 +234,9 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app: Flask) -> None:
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
@ -244,8 +254,10 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "pipeline-1")
|
||||
response, status = method(api, tenant_id, account, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert service.publish_customized_pipeline_template.call_args.args[2] is account
|
||||
assert service.publish_customized_pipeline_template.call_args.args[3] == tenant_id
|
||||
assert status == 204
|
||||
assert response == ""
|
||||
|
||||
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import TypedDict, Unpack
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
@ -35,12 +34,15 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
RagPipelineWorkflowRunNodeExecutionListApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
@ -71,7 +73,7 @@ class WorkflowFactoryPayload(TypedDict):
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
updated_at: datetime | None
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
@ -90,39 +92,69 @@ class WorkflowFactoryOverrides(TypedDict, total=False):
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
updated_by: str | None
|
||||
updated_at: datetime
|
||||
updated_at: datetime | None
|
||||
environment_variables: list[WorkflowVariablePayload]
|
||||
conversation_variables: list[WorkflowVariablePayload]
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
def make_node_execution(**overrides: object) -> SimpleNamespace:
|
||||
payload: dict[str, object] = {
|
||||
class NodeExecutionOverrides(TypedDict, total=False):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
index: int
|
||||
predecessor_node_id: str | None
|
||||
node_execution_id: str | None
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
inputs: str | None
|
||||
process_data: str | None
|
||||
outputs: str | None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None
|
||||
elapsed_time: float
|
||||
execution_metadata: str | None
|
||||
created_at: datetime
|
||||
created_by_role: CreatorUserRole
|
||||
created_by: str
|
||||
finished_at: datetime | None
|
||||
|
||||
|
||||
def make_node_execution(**overrides: Unpack[NodeExecutionOverrides]) -> WorkflowNodeExecutionModel:
|
||||
payload: NodeExecutionOverrides = {
|
||||
"id": "node-exec-1",
|
||||
"tenant_id": DEFAULT_WORKFLOW_TENANT_ID,
|
||||
"app_id": DEFAULT_WORKFLOW_APP_ID,
|
||||
"workflow_id": "workflow-1",
|
||||
"workflow_run_id": None,
|
||||
"index": 1,
|
||||
"predecessor_node_id": None,
|
||||
"node_execution_id": None,
|
||||
"node_id": "node1",
|
||||
"node_type": "start",
|
||||
"title": "Start",
|
||||
"inputs_dict": {"query": "hello"},
|
||||
"process_data_dict": {},
|
||||
"outputs_dict": {"answer": "world"},
|
||||
"status": "succeeded",
|
||||
"inputs": json.dumps({"query": "hello"}),
|
||||
"process_data": json.dumps({}),
|
||||
"outputs": json.dumps({"answer": "world"}),
|
||||
"status": WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
"error": None,
|
||||
"elapsed_time": 1.0,
|
||||
"execution_metadata_dict": {},
|
||||
"extras": {},
|
||||
"execution_metadata": json.dumps({}),
|
||||
"created_at": datetime(2026, 1, 1, 0, 0, 0),
|
||||
"created_by_role": "account",
|
||||
"created_by_account": None,
|
||||
"created_by_end_user": None,
|
||||
"created_by_role": CreatorUserRole.ACCOUNT,
|
||||
"created_by": DEFAULT_WORKFLOW_CREATED_BY,
|
||||
"finished_at": datetime(2026, 1, 1, 0, 0, 1),
|
||||
"inputs_truncated": False,
|
||||
"outputs_truncated": False,
|
||||
"process_data_truncated": False,
|
||||
}
|
||||
payload.update(overrides)
|
||||
return SimpleNamespace(**payload)
|
||||
execution = WorkflowNodeExecutionModel(
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||
**payload,
|
||||
)
|
||||
execution.offload_data = []
|
||||
return execution
|
||||
|
||||
|
||||
def default_workflow_payload() -> WorkflowFactoryPayload:
|
||||
@ -274,7 +306,10 @@ class TestDraftWorkflowApi:
|
||||
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="account-1")
|
||||
workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))
|
||||
workflow = make_workflow(
|
||||
graph=json.dumps({"nodes": [{"id": "restored"}], "edges": []}),
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
|
||||
service = MagicMock()
|
||||
service.restore_published_workflow_to_draft.return_value = workflow
|
||||
@ -289,7 +324,7 @@ class TestDraftWorkflowApi:
|
||||
result = method(api, user, pipeline, "published-workflow")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert result["hash"] == "restored-hash"
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
|
||||
def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
@ -515,10 +550,7 @@ class TestPublishedPipelineApis:
|
||||
|
||||
user = make_account(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id=str(uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
workflow = make_workflow(id=str(uuid4()), created_at=naive_utc_now())
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
@ -576,6 +608,8 @@ class TestMiscApis:
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
user = make_account()
|
||||
tenant_id = "tenant-1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
@ -584,8 +618,9 @@ class TestMiscApis:
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, user)
|
||||
assert result == [{"id": "p1"}]
|
||||
service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id)
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@ -814,7 +849,7 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = make_pipeline()
|
||||
workflow = MagicMock()
|
||||
workflow = make_workflow()
|
||||
node_exec = make_node_execution()
|
||||
|
||||
service = MagicMock()
|
||||
@ -853,6 +888,42 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowRunNodeExecutionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_node_executions_passes_current_user(self, app: Flask) -> None:
|
||||
api = RagPipelineWorkflowRunNodeExecutionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = make_account()
|
||||
pipeline = make_pipeline()
|
||||
run_id = uuid4()
|
||||
node_exec = make_node_execution(workflow_run_id=str(run_id))
|
||||
|
||||
service = MagicMock()
|
||||
service.get_rag_pipeline_workflow_run_node_executions.return_value = [node_exec]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, user, pipeline, run_id)
|
||||
|
||||
service.get_rag_pipeline_workflow_run_node_executions.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
run_id=str(run_id),
|
||||
user=user,
|
||||
)
|
||||
assert result["data"][0]["id"] == "node-exec-1"
|
||||
assert result["data"][0]["inputs"] == {"query": "hello"}
|
||||
assert result["data"][0]["outputs"] == {"answer": "world"}
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
|
||||
@ -2,7 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from dataclasses import dataclass
|
||||
from inspect import unwrap
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -10,85 +13,103 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.enums import ConversationFromSource, ConversationStatus
|
||||
from models.model import App, AppMode, Conversation, InstalledApp
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class FakeConversation:
|
||||
def __init__(self, cid):
|
||||
self.id = cid
|
||||
self.name = "test"
|
||||
self.inputs = {}
|
||||
self.status = "normal"
|
||||
self.introduction = ""
|
||||
@dataclass
|
||||
class InstalledAppCarrier:
|
||||
app: App | None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_app():
|
||||
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
|
||||
return MagicMock(app=app_model)
|
||||
def chat_app() -> InstalledApp:
|
||||
app_model = App(
|
||||
tenant_id="tenant-1",
|
||||
name="Chat App",
|
||||
mode=AppMode.CHAT,
|
||||
enable_site=True,
|
||||
enable_api=False,
|
||||
)
|
||||
app_model.id = "app-id"
|
||||
return cast(InstalledApp, InstalledAppCarrier(app=app_model))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_chat_app():
|
||||
app_model = MagicMock(mode=AppMode.COMPLETION)
|
||||
return MagicMock(app=app_model)
|
||||
def non_chat_app() -> InstalledApp:
|
||||
app_model = App(
|
||||
tenant_id="tenant-1",
|
||||
name="Completion App",
|
||||
mode=AppMode.COMPLETION,
|
||||
enable_site=True,
|
||||
enable_api=False,
|
||||
)
|
||||
app_model.id = "app-id"
|
||||
return cast(InstalledApp, InstalledAppCarrier(app=app_model))
|
||||
|
||||
|
||||
def make_conversation(*, id: str) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id="app-id",
|
||||
mode=AppMode.CHAT,
|
||||
name="test",
|
||||
from_source=ConversationFromSource.API,
|
||||
)
|
||||
conversation.id = id
|
||||
conversation.inputs = {}
|
||||
conversation.status = ConversationStatus.NORMAL
|
||||
conversation.introduction = ""
|
||||
return conversation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
user = MagicMock(spec=Account)
|
||||
def user() -> Account:
|
||||
user = Account(name="User", email="user.com")
|
||||
user.id = "uid"
|
||||
return user
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
def test_get_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pagination = MagicMock(
|
||||
pagination = InfiniteScrollPagination(
|
||||
data=[make_conversation(id="c1"), make_conversation(id="c2")],
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[FakeConversation("c1"), FakeConversation("c2")],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?limit=20"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
result = method(api, user, chat_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
@ -96,47 +117,45 @@ class TestConversationListApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app)
|
||||
method(api, user, non_chat_app)
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
def test_delete_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
result = method(api, user, chat_app, "cid")
|
||||
|
||||
body, status = result
|
||||
assert status == 204
|
||||
assert body == ""
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
def test_delete_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
@ -144,48 +163,46 @@ class TestConversationApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
method(api, user, chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app, "cid")
|
||||
method(api, user, non_chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
def test_rename_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
conversation = FakeConversation("cid")
|
||||
conversation = make_conversation(id="cid")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
return_value=conversation,
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
result = method(api, user, chat_app, "cid")
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
def test_rename_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
@ -193,48 +210,46 @@ class TestConversationRenameApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
method(api, user, chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
def test_pin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
result = method(api, user, chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
def test_unpin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None:
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"unpin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
result = method(api, user, chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -31,123 +32,110 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def mock_user():
|
||||
user = MagicMock(spec=Account)
|
||||
def mock_user() -> Account:
|
||||
user = Account(name="User", email="user.com")
|
||||
user.id = "u1"
|
||||
user.current_tenant_id = "t1"
|
||||
return user
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_icon_success(self, app: Flask):
|
||||
def test_icon_success(self, app: Flask) -> None:
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
|
||||
return_value="icon",
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == "icon"
|
||||
assert method(api, "t1", "github") == "icon"
|
||||
|
||||
def test_list_providers(self, app: Flask):
|
||||
def test_list_providers(self, app: Flask) -> None:
|
||||
api = TriggerProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api) == []
|
||||
assert method(api, "t1") == []
|
||||
|
||||
def test_provider_info(self, app: Flask):
|
||||
def test_provider_info(self, app: Flask) -> None:
|
||||
api = TriggerProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
|
||||
return_value={"id": "p1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"id": "p1"}
|
||||
assert method(api, "t1", "github") == {"id": "p1"}
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_success(self, app: Flask):
|
||||
def test_list_success(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == []
|
||||
assert method(api, "t1", mock_user(), "github") == []
|
||||
|
||||
def test_list_invalid_provider(self, app: Flask):
|
||||
def test_list_invalid_provider(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, "bad")
|
||||
result, status = method(api, "t1", mock_user(), "bad")
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create_builder(self, app: Flask):
|
||||
def test_create_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
result = method(api, "t1", mock_user(), "github")
|
||||
assert "subscription_builder" in result
|
||||
|
||||
def test_get_builder(self, app: Flask):
|
||||
def test_get_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -160,50 +148,47 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_verify_builder(self, app: Flask):
|
||||
def test_verify_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {"a": 1}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"ok": True}
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == {"ok": True}
|
||||
|
||||
def test_verify_builder_error(self, app: Flask):
|
||||
def test_verify_builder_error(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
side_effect=Exception("err"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github", "b1")
|
||||
method(api, "t1", mock_user(), "github", "b1")
|
||||
|
||||
def test_update_builder(self, app: Flask):
|
||||
def test_update_builder(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "n"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
assert method(api, "t1", "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_logs(self, app: Flask):
|
||||
def test_logs(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -212,7 +197,6 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
|
||||
return_value=[log],
|
||||
@ -220,27 +204,26 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
|
||||
def test_build(self, app: Flask):
|
||||
def test_build(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == 200
|
||||
assert method(api, "t1", mock_user(), "github", "b1") == 200
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_update_rename_only(self, app: Flask):
|
||||
def test_update_rename_only(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -250,43 +233,40 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
assert method(api, "t1", "s1") == 200
|
||||
|
||||
def test_update_not_found(self, app: Flask):
|
||||
def test_update_not_found(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "x")
|
||||
method(api, "t1", "x")
|
||||
|
||||
def test_update_rebuild(self, app: Flask):
|
||||
def test_update_rebuild(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.OAUTH2
|
||||
sub.credentials = {}
|
||||
sub.parameters = {}
|
||||
sub.credentials = {"token": "old"}
|
||||
sub.parameters = {"repo": "demo"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
@ -295,9 +275,9 @@ class TestTriggerSubscriptionCrud:
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
|
||||
),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
assert method(api, "t1", "s1") == 200
|
||||
|
||||
def test_delete_subscription(self, app: Flask):
|
||||
def test_delete_subscription(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -305,7 +285,6 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls,
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
|
||||
@ -316,17 +295,16 @@ class TestTriggerSubscriptionCrud:
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
|
||||
result = method(api, "sub1")
|
||||
result = method(api, "t1", "sub1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_subscription_value_error(self, app: Flask):
|
||||
def test_delete_subscription_value_error(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls,
|
||||
patch(
|
||||
@ -338,21 +316,20 @@ class TestTriggerSubscriptionCrud:
|
||||
session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "sub1")
|
||||
method(api, "t1", "sub1")
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_authorize_success(self, app: Flask):
|
||||
def test_oauth_authorize_success(self, app: Flask) -> None:
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
@ -370,25 +347,24 @@ class TestTriggerOAuthApis:
|
||||
return_value=MagicMock(authorization_url="url"),
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
resp = method(api, "t1", mock_user(), "github")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_oauth_authorize_no_client(self, app: Flask):
|
||||
def test_oauth_authorize_no_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "github")
|
||||
method(api, "t1", mock_user(), "github")
|
||||
|
||||
def test_oauth_callback_forbidden(self, app: Flask):
|
||||
def test_oauth_callback_forbidden(self, app: Flask) -> None:
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -396,7 +372,7 @@ class TestTriggerOAuthApis:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_success(self, app: Flask):
|
||||
def test_oauth_callback_success(self, app: Flask) -> None:
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -426,7 +402,7 @@ class TestTriggerOAuthApis:
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 302
|
||||
|
||||
def test_oauth_callback_no_oauth_client(self, app: Flask):
|
||||
def test_oauth_callback_no_oauth_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -450,7 +426,7 @@ class TestTriggerOAuthApis:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_empty_credentials(self, app: Flask):
|
||||
def test_oauth_callback_empty_credentials(self, app: Flask) -> None:
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -481,16 +457,15 @@ class TestTriggerOAuthApis:
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_client(self, app: Flask):
|
||||
def test_get_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
|
||||
return_value={},
|
||||
@ -508,84 +483,79 @@ class TestTriggerOAuthClientManageApi:
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
result = method(api, "t1", "github")
|
||||
assert "configured" in result
|
||||
|
||||
def test_post_client(self, app: Flask):
|
||||
def test_post_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
assert method(api, "t1", "github") == {"ok": True}
|
||||
|
||||
def test_delete_client(self, app: Flask):
|
||||
def test_delete_client(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
assert method(api, "t1", "github") == {"ok": True}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app: Flask):
|
||||
def test_oauth_client_post_value_error(self, app: Flask) -> None:
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github")
|
||||
method(api, "t1", "github")
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_verify_success(self, app: Flask):
|
||||
def test_verify_success(self, app: Flask) -> None:
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "s1") == {"ok": True}
|
||||
assert method(api, "t1", mock_user(), "github", "s1") == {"ok": True}
|
||||
|
||||
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
|
||||
def test_verify_errors(self, app: Flask, raised_exception):
|
||||
def test_verify_errors(self, app: Flask, raised_exception: Exception) -> None:
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
side_effect=raised_exception,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github", "s1")
|
||||
method(api, "t1", mock_user(), "github", "s1")
|
||||
|
||||
@ -11,19 +11,29 @@ Covers:
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
|
||||
from models.enums import DataSourceType
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
def _make_account(account_id: str, tenant_id: str) -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
class TestRagPipelineServiceGetPipeline:
|
||||
"""Integration tests for RagPipelineService.get_pipeline."""
|
||||
|
||||
@ -32,7 +42,7 @@ class TestRagPipelineServiceGetPipeline:
|
||||
yield
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def _make_service(self, flask_app_with_containers) -> RagPipelineService:
|
||||
def _make_service(self, flask_app_with_containers: Flask) -> RagPipelineService:
|
||||
with (
|
||||
patch(
|
||||
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
|
||||
@ -72,7 +82,7 @@ class TestRagPipelineServiceGetPipeline:
|
||||
return dataset
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_not_found(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""get_pipeline raises ValueError when dataset does not exist."""
|
||||
service = self._make_service(flask_app_with_containers)
|
||||
@ -81,7 +91,7 @@ class TestRagPipelineServiceGetPipeline:
|
||||
service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4()))
|
||||
|
||||
def test_get_pipeline_raises_when_pipeline_not_found(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""get_pipeline raises ValueError when dataset exists but has no linked pipeline."""
|
||||
tenant_id = str(uuid4())
|
||||
@ -95,7 +105,7 @@ class TestRagPipelineServiceGetPipeline:
|
||||
service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id)
|
||||
|
||||
def test_get_pipeline_returns_pipeline_when_found(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""get_pipeline returns the Pipeline when both Dataset and Pipeline exist."""
|
||||
tenant_id = str(uuid4())
|
||||
@ -139,43 +149,44 @@ class TestUpdateCustomizedPipelineTemplate:
|
||||
db_session.flush()
|
||||
return template
|
||||
|
||||
def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
|
||||
def test_update_template_succeeds(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""update_customized_pipeline_template updates name and description."""
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
template = self._create_template(db_session_with_containers, tenant_id, created_by)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
|
||||
account = _make_account(created_by, tenant_id)
|
||||
|
||||
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="Updated Name",
|
||||
description="Updated description",
|
||||
icon_info=IconInfo(icon="🔥"),
|
||||
)
|
||||
result = RagPipelineService.update_customized_pipeline_template(template.id, info)
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="Updated Name",
|
||||
description="Updated description",
|
||||
icon_info=IconInfo(icon="🔥"),
|
||||
)
|
||||
result = RagPipelineService.update_customized_pipeline_template(template.id, info, account, tenant_id)
|
||||
|
||||
assert result.name == "Updated Name"
|
||||
assert result.description == "Updated description"
|
||||
|
||||
def test_update_template_raises_when_not_found(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""update_customized_pipeline_template raises ValueError when template doesn't exist."""
|
||||
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
|
||||
tenant_id = str(uuid4())
|
||||
account = _make_account(str(uuid4()), tenant_id)
|
||||
|
||||
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="New Name",
|
||||
description="desc",
|
||||
icon_info=IconInfo(icon="📄"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
RagPipelineService.update_customized_pipeline_template(str(uuid4()), info)
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="New Name",
|
||||
description="desc",
|
||||
icon_info=IconInfo(icon="📄"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
RagPipelineService.update_customized_pipeline_template(str(uuid4()), info, account, tenant_id)
|
||||
|
||||
def test_update_template_raises_on_duplicate_name(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""update_customized_pipeline_template raises ValueError when new name already exists."""
|
||||
tenant_id = str(uuid4())
|
||||
@ -184,16 +195,15 @@ class TestUpdateCustomizedPipelineTemplate:
|
||||
self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate")
|
||||
db_session_with_containers.flush()
|
||||
|
||||
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
|
||||
account = _make_account(created_by, tenant_id)
|
||||
|
||||
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="Duplicate",
|
||||
description="desc",
|
||||
icon_info=IconInfo(icon="📄"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Template name is already exists"):
|
||||
RagPipelineService.update_customized_pipeline_template(template1.id, info)
|
||||
info = PipelineTemplateInfoEntity(
|
||||
name="Duplicate",
|
||||
description="desc",
|
||||
icon_info=IconInfo(icon="📄"),
|
||||
)
|
||||
with pytest.raises(ValueError, match="Template name is already exists"):
|
||||
RagPipelineService.update_customized_pipeline_template(template1.id, info, account, tenant_id)
|
||||
|
||||
|
||||
class TestDeleteCustomizedPipelineTemplate:
|
||||
@ -221,7 +231,9 @@ class TestDeleteCustomizedPipelineTemplate:
|
||||
db_session.flush()
|
||||
return template
|
||||
|
||||
def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None:
|
||||
def test_delete_template_succeeds(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""delete_customized_pipeline_template removes the template from the DB."""
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
@ -229,27 +241,23 @@ class TestDeleteCustomizedPipelineTemplate:
|
||||
template_id = template.id
|
||||
db_session_with_containers.flush()
|
||||
|
||||
fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id)
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id, tenant_id)
|
||||
|
||||
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
# Verify the record is deleted within the same context
|
||||
from sqlalchemy import select
|
||||
|
||||
# Verify the record is deleted within the same context
|
||||
from sqlalchemy import select
|
||||
from extensions.ext_database import db as ext_db
|
||||
|
||||
from extensions.ext_database import db as ext_db
|
||||
|
||||
remaining = ext_db.session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
|
||||
)
|
||||
assert remaining is None
|
||||
remaining = ext_db.session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id)
|
||||
)
|
||||
assert remaining is None
|
||||
|
||||
def test_delete_template_raises_when_not_found(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
) -> None:
|
||||
"""delete_customized_pipeline_template raises ValueError when template doesn't exist."""
|
||||
fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4()))
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user):
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
RagPipelineService.delete_customized_pipeline_template(str(uuid4()))
|
||||
with pytest.raises(ValueError, match="Customized pipeline template not found"):
|
||||
RagPipelineService.delete_customized_pipeline_template(str(uuid4()), tenant_id)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -8,6 +8,7 @@ from flask import Flask
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, DatasetMetadataBinding, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -33,7 +34,7 @@ def _create_dataset(db_session: Session, *, tenant_id: str, built_in_field_enabl
|
||||
|
||||
|
||||
def _create_document(
|
||||
db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None
|
||||
db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict[str, str] | None = None
|
||||
) -> Document:
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
@ -63,18 +64,21 @@ class TestMetadataPartialUpdate:
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(self, user_id, tenant_id):
|
||||
account = Mock(id=user_id, current_tenant_id=tenant_id)
|
||||
with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)):
|
||||
yield account
|
||||
def current_account(self, user_id: str, tenant_id: str) -> Account:
|
||||
account = Account(name="Test User", email=f"{user_id}@example.com")
|
||||
account.id = user_id
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
def test_partial_update_merges_metadata(
|
||||
self,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
mock_current_account,
|
||||
):
|
||||
current_account: Account,
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -91,7 +95,7 @@ class TestMetadataPartialUpdate:
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
updated_doc = db_session_with_containers.get(Document, document.id)
|
||||
@ -104,8 +108,8 @@ class TestMetadataPartialUpdate:
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
mock_current_account,
|
||||
):
|
||||
current_account: Account,
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -122,7 +126,7 @@ class TestMetadataPartialUpdate:
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
updated_doc = db_session_with_containers.get(Document, document.id)
|
||||
@ -134,10 +138,10 @@ class TestMetadataPartialUpdate:
|
||||
self,
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
user_id,
|
||||
mock_current_account,
|
||||
):
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
current_account: Account,
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -164,7 +168,7 @@ class TestMetadataPartialUpdate:
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
bindings = db_session_with_containers.scalars(
|
||||
@ -180,8 +184,8 @@ class TestMetadataPartialUpdate:
|
||||
flask_app_with_containers: Flask,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
mock_current_account,
|
||||
):
|
||||
current_account: Account,
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
db_session_with_containers,
|
||||
@ -200,4 +204,4 @@ class TestMetadataPartialUpdate:
|
||||
|
||||
with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")):
|
||||
with pytest.raises(RuntimeError, match="database connection lost"):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_account)
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from unittest.mock import create_autospec, patch
|
||||
from collections.abc import Generator
|
||||
from typing import TypedDict
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -6,21 +8,25 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document
|
||||
from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class MetadataServiceDeps(TypedDict):
|
||||
redis_client: Mock
|
||||
document_service: Mock
|
||||
|
||||
|
||||
class TestMetadataService:
|
||||
"""Integration tests for MetadataService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
def mock_external_service_dependencies(self) -> Generator[MetadataServiceDeps, None, None]:
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("libs.login.current_user", create_autospec(Account, instance=True)) as mock_current_user,
|
||||
patch("services.metadata_service.redis_client") as mock_redis_client,
|
||||
patch("services.dataset_service.DocumentService") as mock_document_service,
|
||||
):
|
||||
@ -30,12 +36,15 @@ class TestMetadataService:
|
||||
mock_redis_client.delete.return_value = 1
|
||||
|
||||
yield {
|
||||
"current_user": mock_current_user,
|
||||
"redis_client": mock_redis_client,
|
||||
"document_service": mock_document_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MetadataServiceDeps,
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -53,7 +62,7 @@ class TestMetadataService:
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
|
||||
db_session_with_containers.add(account)
|
||||
@ -62,7 +71,7 @@ class TestMetadataService:
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
status=TenantStatus.NORMAL,
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
@ -83,8 +92,12 @@ class TestMetadataService:
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant
|
||||
):
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MetadataServiceDeps,
|
||||
account: Account,
|
||||
tenant: Tenant,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -114,8 +127,12 @@ class TestMetadataService:
|
||||
return dataset
|
||||
|
||||
def _create_test_document(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account
|
||||
):
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MetadataServiceDeps,
|
||||
dataset: Dataset,
|
||||
account: Account,
|
||||
) -> Document:
|
||||
"""
|
||||
Helper method to create a test document for testing.
|
||||
|
||||
@ -149,7 +166,9 @@ class TestMetadataService:
|
||||
|
||||
return document
|
||||
|
||||
def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def test_create_metadata_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful metadata creation with valid parameters.
|
||||
"""
|
||||
@ -161,14 +180,10 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
result = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -185,8 +200,8 @@ class TestMetadataService:
|
||||
assert result.created_at is not None
|
||||
|
||||
def test_create_metadata_name_too_long(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata creation fails when name exceeds 255 characters.
|
||||
"""
|
||||
@ -198,20 +213,16 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
long_name = "a" * 256 # 256 characters, exceeding 255 limit
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name)
|
||||
metadata_args = MetadataArgs(type="string", name=long_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
|
||||
MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
def test_create_metadata_name_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata creation fails when name already exists in the same dataset.
|
||||
"""
|
||||
@ -223,24 +234,20 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create first metadata
|
||||
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name")
|
||||
MetadataService.create_metadata(dataset.id, first_metadata_args)
|
||||
first_metadata_args = MetadataArgs(type="string", name="duplicate_name")
|
||||
MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id)
|
||||
|
||||
# Try to create second metadata with same name
|
||||
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name")
|
||||
second_metadata_args = MetadataArgs(type="number", name="duplicate_name")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name already exists."):
|
||||
MetadataService.create_metadata(dataset.id, second_metadata_args)
|
||||
MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id)
|
||||
|
||||
def test_create_metadata_name_conflicts_with_built_in_field(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata creation fails when name conflicts with built-in field names.
|
||||
"""
|
||||
@ -252,21 +259,17 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Try to create metadata with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name)
|
||||
metadata_args = MetadataArgs(type="string", name=built_in_field_name)
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
def test_update_metadata_name_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful metadata name update with valid parameters.
|
||||
"""
|
||||
@ -278,17 +281,13 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
new_name = "new_name"
|
||||
result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name)
|
||||
result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name, account, tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -302,8 +301,8 @@ class TestMetadataService:
|
||||
assert result.name == new_name
|
||||
|
||||
def test_update_metadata_name_too_long(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata name update fails when new name exceeds 255 characters.
|
||||
"""
|
||||
@ -315,24 +314,20 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Try to update with too long name
|
||||
long_name = "a" * 256 # 256 characters, exceeding 255 limit
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name)
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name, account, tenant.id)
|
||||
|
||||
def test_update_metadata_name_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata name update fails when new name already exists in the same dataset.
|
||||
"""
|
||||
@ -344,24 +339,20 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create two metadata entries
|
||||
first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata")
|
||||
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args)
|
||||
first_metadata_args = MetadataArgs(type="string", name="first_metadata")
|
||||
first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id)
|
||||
|
||||
second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata")
|
||||
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args)
|
||||
second_metadata_args = MetadataArgs(type="number", name="second_metadata")
|
||||
second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id)
|
||||
|
||||
# Try to update first metadata with second metadata's name
|
||||
with pytest.raises(ValueError, match="Metadata name already exists."):
|
||||
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata")
|
||||
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata", account, tenant.id)
|
||||
|
||||
def test_update_metadata_name_conflicts_with_built_in_field(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata name update fails when new name conflicts with built-in field names.
|
||||
"""
|
||||
@ -373,23 +364,19 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="old_name")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Try to update with built-in field name
|
||||
built_in_field_name = BuiltInField.document_name
|
||||
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name, account, tenant.id)
|
||||
|
||||
def test_update_metadata_name_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata name update fails when metadata ID does not exist.
|
||||
"""
|
||||
@ -401,10 +388,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Try to update non-existent metadata
|
||||
import uuid
|
||||
|
||||
@ -412,12 +395,14 @@ class TestMetadataService:
|
||||
new_name = "new_name"
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name)
|
||||
result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name, account, tenant.id)
|
||||
|
||||
# Assert: Verify the method returns None when metadata is not found
|
||||
assert result is None
|
||||
|
||||
def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def test_delete_metadata_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful metadata deletion with valid parameters.
|
||||
"""
|
||||
@ -429,13 +414,9 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata first
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="to_be_deleted")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.delete_metadata(dataset.id, metadata.id)
|
||||
@ -449,7 +430,9 @@ class TestMetadataService:
|
||||
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
|
||||
assert deleted_metadata is None
|
||||
|
||||
def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def test_delete_metadata_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata deletion fails when metadata ID does not exist.
|
||||
"""
|
||||
@ -461,10 +444,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Try to delete non-existent metadata
|
||||
import uuid
|
||||
|
||||
@ -477,8 +456,8 @@ class TestMetadataService:
|
||||
assert result is None
|
||||
|
||||
def test_delete_metadata_with_document_bindings(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata deletion successfully removes document metadata bindings.
|
||||
"""
|
||||
@ -493,13 +472,9 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, dataset, account
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Create metadata binding
|
||||
binding = DatasetMetadataBinding(
|
||||
@ -531,7 +506,9 @@ class TestMetadataService:
|
||||
# Note: The service attempts to update document metadata but may not succeed
|
||||
# due to mock configuration. The main functionality (metadata deletion) is verified.
|
||||
|
||||
def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
def test_get_built_in_fields_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful retrieval of built-in metadata fields.
|
||||
"""
|
||||
@ -557,8 +534,8 @@ class TestMetadataService:
|
||||
assert "time" in field_types
|
||||
|
||||
def test_enable_built_in_field_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful enabling of built-in fields for a dataset.
|
||||
"""
|
||||
@ -573,10 +550,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, dataset, account
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [
|
||||
document
|
||||
@ -591,14 +564,14 @@ class TestMetadataService:
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert dataset.built_in_field_enabled
|
||||
|
||||
# Note: Document metadata update depends on DocumentService mock working correctly
|
||||
# The main functionality (enabling built-in fields) is verified
|
||||
|
||||
def test_enable_built_in_field_already_enabled(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test enabling built-in fields when they are already enabled.
|
||||
"""
|
||||
@ -610,10 +583,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
|
||||
@ -621,7 +590,9 @@ class TestMetadataService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
|
||||
Document
|
||||
]()
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
@ -631,8 +602,8 @@ class TestMetadataService:
|
||||
assert dataset.built_in_field_enabled is True
|
||||
|
||||
def test_enable_built_in_field_with_no_documents(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test enabling built-in fields for a dataset with no documents.
|
||||
"""
|
||||
@ -644,12 +615,10 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
|
||||
Document
|
||||
]()
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
@ -657,11 +626,11 @@ class TestMetadataService:
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert dataset.built_in_field_enabled
|
||||
|
||||
def test_disable_built_in_field_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful disabling of built-in fields for a dataset.
|
||||
"""
|
||||
@ -676,10 +645,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, dataset, account
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
|
||||
@ -713,8 +678,8 @@ class TestMetadataService:
|
||||
# The main functionality (disabling built-in fields) is verified
|
||||
|
||||
def test_disable_built_in_field_already_disabled(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test disabling built-in fields when they are already disabled.
|
||||
"""
|
||||
@ -726,15 +691,13 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Verify dataset starts with built-in fields disabled
|
||||
assert dataset.built_in_field_enabled is False
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
|
||||
Document
|
||||
]()
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
@ -742,11 +705,11 @@ class TestMetadataService:
|
||||
# Assert: Verify the method returns early without changes
|
||||
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is False
|
||||
assert not dataset.built_in_field_enabled
|
||||
|
||||
def test_disable_built_in_field_with_no_documents(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test disabling built-in fields for a dataset with no documents.
|
||||
"""
|
||||
@ -758,10 +721,6 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
|
||||
@ -769,7 +728,9 @@ class TestMetadataService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[
|
||||
Document
|
||||
]()
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
@ -779,8 +740,8 @@ class TestMetadataService:
|
||||
assert dataset.built_in_field_enabled is False
|
||||
|
||||
def test_update_documents_metadata_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful update of documents metadata.
|
||||
"""
|
||||
@ -795,13 +756,9 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, dataset, account
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Mock DocumentService.get_document
|
||||
mock_external_service_dependencies["document_service"].get_document.return_value = document
|
||||
@ -820,7 +777,7 @@ class TestMetadataService:
|
||||
operation_data = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.update_documents_metadata(dataset, operation_data)
|
||||
MetadataService.update_documents_metadata(dataset, operation_data, account)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -841,8 +798,8 @@ class TestMetadataService:
|
||||
assert binding.dataset_id == dataset.id
|
||||
|
||||
def test_update_documents_metadata_with_built_in_fields_enabled(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test update of documents metadata when built-in fields are enabled.
|
||||
"""
|
||||
@ -863,13 +820,9 @@ class TestMetadataService:
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Mock DocumentService.get_document
|
||||
mock_external_service_dependencies["document_service"].get_document.return_value = document
|
||||
@ -888,7 +841,7 @@ class TestMetadataService:
|
||||
operation_data = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act: Execute the method under test
|
||||
MetadataService.update_documents_metadata(dataset, operation_data)
|
||||
MetadataService.update_documents_metadata(dataset, operation_data, account)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify document metadata was updated with both custom and built-in fields
|
||||
@ -901,8 +854,8 @@ class TestMetadataService:
|
||||
# The main functionality (custom metadata update) is verified
|
||||
|
||||
def test_update_documents_metadata_document_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test update of documents metadata when document is not found.
|
||||
"""
|
||||
@ -914,13 +867,9 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Create metadata operation data
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -941,11 +890,11 @@ class TestMetadataService:
|
||||
# Act & Assert: The method should raise ValueError("Document not found.")
|
||||
# because the exception is now re-raised after rollback
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
MetadataService.update_documents_metadata(dataset, operation_data)
|
||||
MetadataService.update_documents_metadata(dataset, operation_data, account)
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_dataset_id(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata lock check for dataset operations.
|
||||
"""
|
||||
@ -967,8 +916,8 @@ class TestMetadataService:
|
||||
assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}"
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_document_id(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata lock check for document operations.
|
||||
"""
|
||||
@ -990,8 +939,8 @@ class TestMetadataService:
|
||||
assert call_args[0][0] == f"document_metadata_lock_{document_id}"
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_lock_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata lock check when lock already exists.
|
||||
"""
|
||||
@ -1007,8 +956,8 @@ class TestMetadataService:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_document_lock_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test metadata lock check when document lock already exists.
|
||||
"""
|
||||
@ -1022,8 +971,8 @@ class TestMetadataService:
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, document_id)
|
||||
|
||||
def test_get_dataset_metadatas_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test successful retrieval of dataset metadata information.
|
||||
"""
|
||||
@ -1035,13 +984,9 @@ class TestMetadataService:
|
||||
db_session_with_containers, mock_external_service_dependencies, account, tenant
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Create document and metadata binding
|
||||
document = self._create_test_document(
|
||||
@ -1079,8 +1024,8 @@ class TestMetadataService:
|
||||
assert result["built_in_field_enabled"] is False
|
||||
|
||||
def test_get_dataset_metadatas_with_built_in_fields_enabled(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test retrieval of dataset metadata when built-in fields are enabled.
|
||||
"""
|
||||
@ -1098,13 +1043,9 @@ class TestMetadataService:
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
mock_external_service_dependencies["current_user"].id = account.id
|
||||
|
||||
# Create metadata
|
||||
metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
metadata_args = MetadataArgs(type="string", name="test_metadata")
|
||||
metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id)
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
@ -1122,8 +1063,8 @@ class TestMetadataService:
|
||||
assert result["built_in_field_enabled"] is True
|
||||
|
||||
def test_get_dataset_metadatas_no_metadata(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps
|
||||
) -> None:
|
||||
"""
|
||||
Test retrieval of dataset metadata when no metadata exists.
|
||||
"""
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.account import Account
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
|
||||
@ -50,24 +51,31 @@ def _payload() -> dict[str, object]:
|
||||
}
|
||||
|
||||
|
||||
def _account() -> Account:
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.id = "account-1"
|
||||
return account
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None:
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
service_calls: list[tuple[str, str]] = []
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, str, str]] = []
|
||||
|
||||
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
|
||||
service_calls.append((template_type, language))
|
||||
def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]:
|
||||
service_calls.append((template_type, language, current_tenant_id))
|
||||
return {"pipeline_templates": [_template_item()]}
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipeline/templates"),
|
||||
patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, tenant_id)
|
||||
|
||||
assert status == 200
|
||||
assert service_calls == [("built-in", "en-US")]
|
||||
assert service_calls == [("built-in", "en-US", tenant_id)]
|
||||
assert response == {
|
||||
"pipeline_templates": [
|
||||
{
|
||||
@ -81,21 +89,22 @@ class TestPipelineTemplateListApi:
|
||||
def test_get_passes_explicit_query_to_service(self, app: Flask) -> None:
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
service_calls: list[tuple[str, str]] = []
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, str, str]] = []
|
||||
|
||||
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
|
||||
service_calls.append((template_type, language))
|
||||
def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]:
|
||||
service_calls.append((template_type, language, current_tenant_id))
|
||||
return {"pipeline_templates": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipeline/templates?type=customized&language=ja-JP"),
|
||||
patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, tenant_id)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"pipeline_templates": []}
|
||||
assert service_calls == [("customized", "ja-JP")]
|
||||
assert service_calls == [("customized", "ja-JP", tenant_id)]
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@ -140,22 +149,28 @@ class TestCustomizedPipelineTemplateApi:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
payload = _payload()
|
||||
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
|
||||
account = _account()
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = []
|
||||
|
||||
def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None:
|
||||
service_calls.append((template_id, template_info))
|
||||
def update_template(
|
||||
template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str
|
||||
) -> None:
|
||||
service_calls.append((template_id, template_info, current_user, current_tenant_id))
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template),
|
||||
):
|
||||
response, status = method(api, "template-1")
|
||||
response, status = method(api, tenant_id, account, "template-1")
|
||||
|
||||
assert (response, status) == ("", 204)
|
||||
assert len(service_calls) == 1
|
||||
template_id, template_info = service_calls[0]
|
||||
template_id, template_info, current_user, current_tenant_id = service_calls[0]
|
||||
assert template_id == "template-1"
|
||||
assert current_user is account
|
||||
assert current_tenant_id == tenant_id
|
||||
assert template_info.name == "Updated template"
|
||||
assert template_info.description == "Updated description"
|
||||
assert template_info.icon_info.model_dump() == {
|
||||
@ -172,22 +187,28 @@ class TestCustomizedPipelineTemplateApi:
|
||||
"name": "Updated template",
|
||||
"description": "Updated description",
|
||||
}
|
||||
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
|
||||
account = _account()
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = []
|
||||
|
||||
def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None:
|
||||
service_calls.append((template_id, template_info))
|
||||
def update_template(
|
||||
template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str
|
||||
) -> None:
|
||||
service_calls.append((template_id, template_info, current_user, current_tenant_id))
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template),
|
||||
):
|
||||
response, status = method(api, "template-1")
|
||||
response, status = method(api, tenant_id, account, "template-1")
|
||||
|
||||
assert (response, status) == ("", 204)
|
||||
assert len(service_calls) == 1
|
||||
template_id, template_info = service_calls[0]
|
||||
template_id, template_info, current_user, current_tenant_id = service_calls[0]
|
||||
assert template_id == "template-1"
|
||||
assert current_user is account
|
||||
assert current_tenant_id == tenant_id
|
||||
assert template_info.icon_info.model_dump() == {
|
||||
"icon": "",
|
||||
"icon_background": None,
|
||||
@ -198,19 +219,20 @@ class TestCustomizedPipelineTemplateApi:
|
||||
def test_delete_returns_empty_204(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
deleted_template_ids: list[str] = []
|
||||
tenant_id = "tenant-1"
|
||||
deleted_templates: list[tuple[str, str]] = []
|
||||
|
||||
def delete_template(template_id: str) -> None:
|
||||
deleted_template_ids.append(template_id)
|
||||
def delete_template(template_id: str, current_tenant_id: str) -> None:
|
||||
deleted_templates.append((template_id, current_tenant_id))
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipeline/customized/templates/template-1", method="DELETE"),
|
||||
patch.object(module.RagPipelineService, "delete_customized_pipeline_template", side_effect=delete_template),
|
||||
):
|
||||
response, status = method(api, "template-1")
|
||||
response, status = method(api, tenant_id, "template-1")
|
||||
|
||||
assert (response, status) == ("", 204)
|
||||
assert deleted_template_ids == ["template-1"]
|
||||
assert deleted_templates == [("template-1", tenant_id)]
|
||||
|
||||
def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
@ -292,21 +314,25 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
payload = _payload()
|
||||
service_calls: list[tuple[str, dict[str, object]]] = []
|
||||
account = _account()
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, dict[str, object], Account, str]] = []
|
||||
|
||||
class Service:
|
||||
def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None:
|
||||
service_calls.append((pipeline_id, data))
|
||||
def publish_customized_pipeline_template(
|
||||
self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str
|
||||
) -> None:
|
||||
service_calls.append((pipeline_id, data, current_user, current_tenant_id))
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(module, "RagPipelineService", Service),
|
||||
):
|
||||
response, status = method(api, "pipeline-1")
|
||||
response, status = method(api, tenant_id, account, "pipeline-1")
|
||||
|
||||
assert (response, status) == ("", 204)
|
||||
assert service_calls == [("pipeline-1", payload)]
|
||||
assert service_calls == [("pipeline-1", payload, account, tenant_id)]
|
||||
|
||||
def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None:
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
@ -315,18 +341,22 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
"name": "Published template",
|
||||
"description": "Description",
|
||||
}
|
||||
service_calls: list[tuple[str, dict[str, object]]] = []
|
||||
account = _account()
|
||||
tenant_id = "tenant-1"
|
||||
service_calls: list[tuple[str, dict[str, object], Account, str]] = []
|
||||
|
||||
class Service:
|
||||
def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None:
|
||||
service_calls.append((pipeline_id, data))
|
||||
def publish_customized_pipeline_template(
|
||||
self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str
|
||||
) -> None:
|
||||
service_calls.append((pipeline_id, data, current_user, current_tenant_id))
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(module, "RagPipelineService", Service),
|
||||
):
|
||||
response, status = method(api, "pipeline-1")
|
||||
response, status = method(api, tenant_id, account, "pipeline-1")
|
||||
|
||||
assert (response, status) == ("", 204)
|
||||
assert service_calls == [
|
||||
@ -341,5 +371,7 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
"icon_url": None,
|
||||
},
|
||||
},
|
||||
account,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -8,16 +9,10 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_hit_testing")
|
||||
@ -35,6 +30,17 @@ def dataset():
|
||||
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account() -> Account:
|
||||
account = Account(name="User", email="user@example.com")
|
||||
account.id = "account-1"
|
||||
tenant = Tenant(name="Tenant")
|
||||
tenant.id = "tenant-1"
|
||||
account._current_tenant = tenant
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
def hit_testing_record() -> dict[str, object]:
|
||||
return {
|
||||
"segment": {
|
||||
@ -98,7 +104,7 @@ def bypass_decorators(mocker: MockerFixture):
|
||||
|
||||
|
||||
class TestHitTestingApi:
|
||||
def test_hit_testing_success(self, app: Flask, dataset, dataset_id):
|
||||
def test_hit_testing_success(self, app: Flask, dataset, dataset_id, account: Account):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -129,13 +135,13 @@ class TestHitTestingApi:
|
||||
return_value={"query": {"content": "what is vector search"}, "records": []},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
result = method(api, account, "tenant-1", dataset_id)
|
||||
|
||||
assert "query" in result
|
||||
assert "records" in result
|
||||
assert result["records"] == []
|
||||
|
||||
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id):
|
||||
def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id, account: Account):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -167,7 +173,7 @@ class TestHitTestingApi:
|
||||
return_value={"query": {"content": payload["query"]}, "records": records},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
result = method(api, account, "tenant-1", dataset_id)
|
||||
|
||||
assert result["query"] == {"content": payload["query"]}
|
||||
assert result["records"][0]["segment"]["keywords"] == []
|
||||
@ -175,7 +181,7 @@ class TestHitTestingApi:
|
||||
assert result["records"][0]["files"] == []
|
||||
assert result["records"][0]["score"] is None
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id):
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id, account: Account):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -198,9 +204,9 @@ class TestHitTestingApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
method(api, account, "tenant-1", dataset_id)
|
||||
|
||||
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id):
|
||||
def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id, account: Account):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -228,4 +234,4 @@ class TestHitTestingApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
method(api, dataset_id)
|
||||
method(api, account, "tenant-1", dataset_id)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
@ -21,7 +21,7 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.account import Account
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -29,19 +29,15 @@ from services.hit_testing_service import HitTestingService
|
||||
|
||||
@pytest.fixture
|
||||
def account():
|
||||
acc = MagicMock(spec=Account)
|
||||
acc = Account(name="User", email="user@example.com")
|
||||
acc.id = "account-1"
|
||||
tenant = Tenant(name="Tenant")
|
||||
tenant.id = "tenant-1"
|
||||
acc._current_tenant = tenant
|
||||
acc.role = TenantAccountRole.OWNER
|
||||
return acc
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_current_user(mocker, account):
|
||||
"""Patch current_user to a valid Account."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing_base.current_user",
|
||||
account,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1")
|
||||
@ -86,7 +82,7 @@ def hit_testing_record() -> dict[str, object]:
|
||||
|
||||
|
||||
class TestGetAndValidateDataset:
|
||||
def test_success(self, dataset):
|
||||
def test_success(self, dataset, account):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
@ -98,20 +94,20 @@ class TestGetAndValidateDataset:
|
||||
"check_dataset_permission",
|
||||
),
|
||||
):
|
||||
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
|
||||
|
||||
assert result == dataset
|
||||
|
||||
def test_dataset_not_found(self):
|
||||
def test_dataset_not_found(self, account):
|
||||
with patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
|
||||
|
||||
def test_permission_denied(self, dataset):
|
||||
def test_permission_denied(self, dataset, account):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
@ -125,7 +121,7 @@ class TestGetAndValidateDataset:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden, match="no access"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1")
|
||||
|
||||
|
||||
class TestHitTestingArgsCheck:
|
||||
@ -164,7 +160,7 @@ class TestParseArgs:
|
||||
|
||||
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
def test_success(self, dataset, account):
|
||||
response = {
|
||||
"query": {"content": "hello"},
|
||||
"records": [],
|
||||
@ -175,12 +171,12 @@ class TestPerformHitTesting:
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
assert result["query"] == {"content": "hello"}
|
||||
assert result["records"] == []
|
||||
|
||||
def test_success_prepares_nullable_list_fields(self, dataset):
|
||||
def test_success_prepares_nullable_list_fields(self, dataset, account):
|
||||
response = {
|
||||
"query": {"content": "hello"},
|
||||
"records": [hit_testing_record()],
|
||||
@ -191,7 +187,7 @@ class TestPerformHitTesting:
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
assert result["query"] == {"content": "hello"}
|
||||
record = result["records"][0]
|
||||
@ -203,7 +199,7 @@ class TestPerformHitTesting:
|
||||
assert record["tsne_position"] is None
|
||||
assert record["summary"] is None
|
||||
|
||||
def test_invalid_query_response_raises_value_error(self, dataset):
|
||||
def test_invalid_query_response_raises_value_error(self, dataset, account):
|
||||
with (
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
@ -212,7 +208,7 @@ class TestPerformHitTesting:
|
||||
),
|
||||
pytest.raises(ValueError, match="Invalid hit testing query response"),
|
||||
):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_invalid_records_response_raises_value_error(self):
|
||||
with pytest.raises(ValueError, match="Invalid hit testing records response"):
|
||||
@ -222,74 +218,74 @@ class TestPerformHitTesting:
|
||||
with pytest.raises(ValueError, match="Invalid hit testing record response"):
|
||||
DatasetsHitTestingBase._prepare_hit_testing_records(["record"])
|
||||
|
||||
def test_index_not_initialized(self, dataset):
|
||||
def test_index_not_initialized(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=services.errors.index.IndexNotInitializedError(),
|
||||
):
|
||||
with pytest.raises(DatasetNotInitializedError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_provider_token_not_init(self, dataset):
|
||||
def test_provider_token_not_init(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ProviderTokenNotInitError("token missing"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_quota_exceeded(self, dataset):
|
||||
def test_quota_exceeded(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=QuotaExceededError(),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_model_not_supported(self, dataset):
|
||||
def test_model_not_supported(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_llm_bad_request(self, dataset):
|
||||
def test_llm_bad_request(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=LLMBadRequestError("bad request"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_invoke_error(self, dataset):
|
||||
def test_invoke_error(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_value_error(self, dataset):
|
||||
def test_value_error(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ValueError("bad args"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="bad args"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
def test_unexpected_error(self, dataset):
|
||||
def test_unexpected_error(self, dataset, account):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=Exception("boom"),
|
||||
):
|
||||
with pytest.raises(InternalServerError, match="boom"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -14,6 +15,7 @@ from controllers.console.datasets.metadata import (
|
||||
DatasetMetadataCreateApi,
|
||||
DocumentMetadataEditApi,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
@ -22,13 +24,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_dataset_metadata")
|
||||
@ -37,8 +32,8 @@ def app():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
def current_user() -> Account:
|
||||
user = Account(name="Test User", email="test@example.com")
|
||||
user.id = "user-1"
|
||||
return user
|
||||
|
||||
@ -116,7 +111,7 @@ class TestDatasetMetadataCreateApi:
|
||||
return_value={"id": "m1", "type": "string", "name": "author"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, current_user, dataset_id)
|
||||
result, status = method(api, "tenant-1", current_user, dataset_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["type"] == "string"
|
||||
@ -151,7 +146,7 @@ class TestDatasetMetadataCreateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, current_user, dataset_id)
|
||||
method(api, "tenant-1", current_user, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataGetApi:
|
||||
@ -227,7 +222,7 @@ class TestDatasetMetadataApi:
|
||||
return_value={"id": "m1", "type": "string", "name": "updated-name"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, current_user, dataset_id, metadata_id)
|
||||
result, status = method(api, "tenant-1", current_user, dataset_id, metadata_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["type"] == "string"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -15,15 +16,11 @@ from models.model import AppMode
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
return MagicMock(spec=Account)
|
||||
account = Account(name="User", email="user.com")
|
||||
account.id = "uid"
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -59,7 +56,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -71,18 +67,18 @@ class TestCompletionApi:
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(completion_app)
|
||||
result = method(api, user, completion_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_wrong_app_mode(self):
|
||||
def test_post_wrong_app_mode(self, user):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
method(api, user, installed_app)
|
||||
|
||||
def test_conversation_completed(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -91,7 +87,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -99,7 +94,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_internal_error(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -108,7 +103,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -116,7 +110,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_conversation_not_exists(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -125,7 +119,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -133,7 +126,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_app_unavailable(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -142,7 +135,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -150,7 +142,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_provider_not_initialized(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -159,7 +151,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -167,7 +158,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_quota_exceeded(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -176,7 +167,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -184,7 +174,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_model_not_supported(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -193,7 +183,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -201,7 +190,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
def test_invoke_error(self, app: Flask, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
@ -210,7 +199,6 @@ class TestCompletionApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -218,7 +206,7 @@ class TestCompletionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(completion_app)
|
||||
method(api, user, completion_app)
|
||||
|
||||
|
||||
class TestCompletionStopApi:
|
||||
@ -250,7 +238,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -262,18 +249,18 @@ class TestChatApi:
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
result = method(api, user, chat_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_not_chat_app(self):
|
||||
def test_post_not_chat_app(self, user):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app)
|
||||
method(api, user, installed_app)
|
||||
|
||||
def test_rate_limit_error(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -282,7 +269,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -290,7 +276,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_conversation_completed_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -299,7 +285,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -307,7 +292,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_conversation_not_exists_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -316,7 +301,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -324,7 +308,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_app_unavailable_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -333,7 +317,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -341,7 +324,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_provider_not_initialized_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -350,7 +333,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -358,7 +340,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_quota_exceeded_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -367,7 +349,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -375,7 +356,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_model_not_supported_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -384,7 +365,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -392,7 +372,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_invoke_error_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -401,7 +381,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -409,7 +388,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
def test_internal_error_chat(self, app: Flask, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
@ -418,7 +397,6 @@ class TestChatApi:
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -426,7 +404,7 @@ class TestChatApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(chat_app)
|
||||
method(api, user, chat_app)
|
||||
|
||||
|
||||
class TestChatStopApi:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,32 +1,36 @@
|
||||
from inspect import unwrap
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from models import Account
|
||||
from services.feature_service import FeatureModel, LimitationModel, SystemFeatureModel
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def make_account() -> Account:
|
||||
account = Account(name="Alice", email="alice@example.com")
|
||||
account.id = "account-1"
|
||||
return account
|
||||
|
||||
|
||||
class TestFeatureApi:
|
||||
def test_get_tenant_features_success(self, mocker: MockerFixture):
|
||||
from controllers.console.feature import FeatureApi
|
||||
|
||||
features = FeatureModel(
|
||||
knowledge_rate_limit=42,
|
||||
vector_space=LimitationModel(size=1, limit=2),
|
||||
)
|
||||
get_features = mocker.patch("controllers.console.feature.FeatureService.get_features")
|
||||
get_features.return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True},
|
||||
"vector_space": {"size": 1, "limit": 2},
|
||||
}
|
||||
get_features.return_value = features
|
||||
|
||||
api = FeatureApi()
|
||||
|
||||
raw_get = unwrap(FeatureApi.get)
|
||||
result = raw_get(api, "tenant_123")
|
||||
|
||||
assert result == {"features": {"feature_a": True}}
|
||||
expected = features.model_dump()
|
||||
expected.pop("vector_space")
|
||||
assert result == expected
|
||||
get_features.assert_called_once_with("tenant_123", exclude_vector_space=True)
|
||||
|
||||
|
||||
@ -35,7 +39,7 @@ class TestFeatureVectorSpaceApi:
|
||||
from controllers.console.feature import FeatureVectorSpaceApi
|
||||
|
||||
get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space")
|
||||
get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480}
|
||||
get_vector_space.return_value = LimitationModel(size=5120, limit=20480)
|
||||
|
||||
api = FeatureVectorSpaceApi()
|
||||
|
||||
@ -85,22 +89,23 @@ class TestSystemFeatureApi:
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
fake_user.is_authenticated = True
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
account = make_account()
|
||||
current_account = mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant_optional",
|
||||
return_value=(account, "tenant-123"),
|
||||
)
|
||||
system_features = SystemFeatureModel(is_allow_register=True)
|
||||
get_system_features = mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features",
|
||||
return_value=system_features,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": True}}
|
||||
assert result == system_features.model_dump()
|
||||
current_account.assert_called_once_with()
|
||||
get_system_features.assert_called_once_with(is_authenticated=True)
|
||||
|
||||
def test_get_system_features_unauthenticated(self, mocker: MockerFixture):
|
||||
"""
|
||||
@ -109,19 +114,19 @@ class TestSystemFeatureApi:
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
current_account = mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant_optional",
|
||||
return_value=(None, None),
|
||||
)
|
||||
system_features = SystemFeatureModel(is_allow_register=False)
|
||||
get_system_features = mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features",
|
||||
return_value=system_features,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": False}}
|
||||
assert result == system_features.model_dump()
|
||||
current_account.assert_called_once_with()
|
||||
get_system_features.assert_called_once_with(is_authenticated=False)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import override
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -36,6 +37,7 @@ class MockUser(UserMixin):
|
||||
self.id = user_id
|
||||
self.current_tenant_id = "tenant123"
|
||||
|
||||
@override
|
||||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
@ -210,6 +212,7 @@ class TestModelValidationInjection:
|
||||
Handler().post()
|
||||
|
||||
assert exc_info.value.code == 422
|
||||
assert exc_info.value.description is not None
|
||||
assert "count" in exc_info.value.description
|
||||
|
||||
|
||||
|
||||
@ -9,21 +9,21 @@ Strategy:
|
||||
- ``HitTestingApi.post`` is decorated with ``@cloud_edition_billing_rate_limit_check``
|
||||
which preserves ``__wrapped__``. We call ``post.__wrapped__(self, ...)`` to skip
|
||||
the billing decorator and test the business logic directly.
|
||||
- Base-class methods (``get_and_validate_dataset``, ``perform_hit_testing``) read
|
||||
``current_user`` from ``controllers.console.datasets.hit_testing_base``, so we
|
||||
patch it there.
|
||||
- ``validate_dataset_token`` installs the tenant owner account into Flask-Login's
|
||||
request context before calling the handler, so direct method-call tests install
|
||||
the same concrete account on ``g._login_user``.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload
|
||||
from models.account import Account
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
|
||||
@ -131,13 +131,21 @@ class TestHitTestingApiPost:
|
||||
def _dataset(dataset_id: str, tenant_id: str) -> Dataset:
|
||||
return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1")
|
||||
|
||||
@staticmethod
|
||||
def _account(tenant_id: str) -> Account:
|
||||
account = Account(name="Service API", email="service-api@example.com")
|
||||
account.id = "account-1"
|
||||
tenant = Tenant(name="Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_ns,
|
||||
@ -148,6 +156,7 @@ class TestHitTestingApiPost:
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
@ -158,6 +167,8 @@ class TestHitTestingApiPost:
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
# Skip billing decorator via __wrapped__
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
@ -168,10 +179,8 @@ class TestHitTestingApiPost:
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_with_retrieval_model(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_ns,
|
||||
@ -182,6 +191,7 @@ class TestHitTestingApiPost:
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
@ -204,6 +214,8 @@ class TestHitTestingApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@ -218,10 +230,8 @@ class TestHitTestingApiPost:
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_preserves_retrieval_model_metadata_filtering_conditions(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_ns,
|
||||
@ -232,6 +242,7 @@ class TestHitTestingApiPost:
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
@ -260,6 +271,8 @@ class TestHitTestingApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@ -270,10 +283,8 @@ class TestHitTestingApiPost:
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_prepares_nullable_list_fields(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_ns,
|
||||
@ -284,6 +295,7 @@ class TestHitTestingApiPost:
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
@ -297,6 +309,8 @@ class TestHitTestingApiPost:
|
||||
mock_ns.payload = {"query": "legacy query"}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@ -312,10 +326,8 @@ class TestHitTestingApiPost:
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_dataset_not_found(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
@ -323,21 +335,22 @@ class TestHitTestingApiPost:
|
||||
"""Test hit testing with non-existent dataset."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = None
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
with pytest.raises(NotFound):
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_no_dataset_permission(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_ns,
|
||||
app: Flask,
|
||||
@ -347,6 +360,7 @@ class TestHitTestingApiPost:
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = self._dataset(dataset_id, tenant_id)
|
||||
account = self._account(tenant_id)
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError(
|
||||
@ -355,6 +369,8 @@ class TestHitTestingApiPost:
|
||||
mock_ns.payload = {"query": "test query"}
|
||||
|
||||
with app.test_request_context():
|
||||
# TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack
|
||||
g._login_user = account
|
||||
api = HitTestingApi()
|
||||
with pytest.raises(Forbidden):
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask, Response, g
|
||||
from flask_login import UserMixin
|
||||
from pytest_mock import MockerFixture
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import libs.login as login_module
|
||||
from extensions.ext_login import DifyLoginManager
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.account import Account, Tenant
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -23,7 +23,7 @@ def protected_view():
|
||||
return _protected_view
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
class MockUser:
|
||||
"""Mock user class for testing."""
|
||||
|
||||
def __init__(self, id: str, is_authenticated: bool = True):
|
||||
@ -35,6 +35,22 @@ class MockUser(UserMixin):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
class LoginManagerStub:
|
||||
def __init__(self, unauthorized_response: Response) -> None:
|
||||
self._unauthorized_response = unauthorized_response
|
||||
|
||||
def unauthorized(self) -> Response:
|
||||
return self._unauthorized_response
|
||||
|
||||
|
||||
def _login_manager(app: Flask) -> DifyLoginManager:
|
||||
return cast(DifyLoginManager, app.__dict__["login_manager"])
|
||||
|
||||
|
||||
def _unauthorized_mock(app: Flask) -> MagicMock:
|
||||
return cast(MagicMock, _login_manager(app).unauthorized)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def login_app(mocker: MockerFixture) -> Flask:
|
||||
app = Flask(__name__)
|
||||
@ -95,7 +111,7 @@ class TestLoginRequired:
|
||||
|
||||
assert result == "Protected content"
|
||||
resolve_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
_unauthorized_mock(login_app).assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("resolved_user", "description"),
|
||||
@ -120,11 +136,11 @@ class TestLoginRequired:
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
|
||||
assert result is login_app.login_manager.unauthorized.return_value, description
|
||||
assert result is _unauthorized_mock(login_app).return_value, description
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 401
|
||||
resolve_user.assert_called_once_with()
|
||||
login_app.login_manager.unauthorized.assert_called_once_with()
|
||||
_unauthorized_mock(login_app).assert_called_once_with()
|
||||
csrf_check.assert_not_called()
|
||||
|
||||
def test_unauthorized_access_propagates_response_object(
|
||||
@ -138,9 +154,7 @@ class TestLoginRequired:
|
||||
"""Test that unauthorized responses are propagated as Flask Response objects."""
|
||||
resolve_user = resolve_current_user(None)
|
||||
response = Response("Unauthorized", status=401, content_type="application/json")
|
||||
mocker.patch.object(
|
||||
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
|
||||
)
|
||||
mocker.patch.object(login_module, "_get_login_manager", return_value=LoginManagerStub(response))
|
||||
|
||||
with login_app.test_request_context():
|
||||
result = protected_view()
|
||||
@ -177,7 +191,7 @@ class TestLoginRequired:
|
||||
assert result == "Protected content"
|
||||
resolve_user.assert_not_called()
|
||||
csrf_check.assert_not_called()
|
||||
login_app.login_manager.unauthorized.assert_not_called()
|
||||
_unauthorized_mock(login_app).assert_not_called()
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
@ -191,6 +205,7 @@ class TestGetUser:
|
||||
g._login_user = mock_user
|
||||
user = login_module._get_user()
|
||||
assert user == mock_user
|
||||
assert user is not None
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, login_app: Flask, mocker: MockerFixture):
|
||||
@ -201,7 +216,7 @@ class TestGetUser:
|
||||
g._login_user = mock_user
|
||||
|
||||
load_user = mocker.patch.object(
|
||||
login_app.login_manager,
|
||||
_login_manager(login_app),
|
||||
"load_user_from_request_context",
|
||||
side_effect=load_user_from_request_context,
|
||||
)
|
||||
@ -244,7 +259,9 @@ class TestCurrentAccountWithTenant:
|
||||
|
||||
def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123")
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = "tenant-123"
|
||||
account._current_tenant = tenant
|
||||
current_user_proxy = mocker.Mock()
|
||||
current_user_proxy._get_current_object.return_value = account
|
||||
mocker.patch.object(login_module, "current_user", new=current_user_proxy)
|
||||
@ -267,3 +284,58 @@ class TestCurrentAccountWithTenant:
|
||||
|
||||
with pytest.raises(AssertionError, match="tenant information should be loaded"):
|
||||
login_module.current_account_with_tenant()
|
||||
|
||||
|
||||
class TestCurrentAccountWithTenantOptional:
|
||||
"""Test cases for optional current account resolution."""
|
||||
|
||||
def test_returns_account_and_tenant_id_for_authenticated_account(self, mocker: MockerFixture) -> None:
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = "tenant-123"
|
||||
account._current_tenant = tenant
|
||||
mocker.patch.object(login_module, "_resolve_current_user", return_value=account)
|
||||
|
||||
user, tenant_id = login_module.current_account_with_tenant_optional()
|
||||
|
||||
assert user is account
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
def test_returns_none_pair_when_request_loader_raises_unauthorized(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(login_module, "_resolve_current_user", side_effect=Unauthorized())
|
||||
|
||||
user, tenant_id = login_module.current_account_with_tenant_optional()
|
||||
|
||||
assert user is None
|
||||
assert tenant_id is None
|
||||
|
||||
def test_returns_none_pair_when_resolved_user_is_not_account(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(login_module, "_resolve_current_user", return_value=MockUser("end-user"))
|
||||
|
||||
user, tenant_id = login_module.current_account_with_tenant_optional()
|
||||
|
||||
assert user is None
|
||||
assert tenant_id is None
|
||||
|
||||
|
||||
class TestResolveTenantIdFallback:
|
||||
"""Test cases for tenant-only fallback helper."""
|
||||
|
||||
def test_returns_provided_tenant_id_without_current_user_lookup(self, mocker: MockerFixture) -> None:
|
||||
current_account_with_tenant = mocker.patch.object(login_module, "current_account_with_tenant")
|
||||
|
||||
tenant_id = login_module.resolve_tenant_id_fallback("tenant-123")
|
||||
|
||||
assert tenant_id == "tenant-123"
|
||||
current_account_with_tenant.assert_not_called()
|
||||
|
||||
def test_falls_back_to_current_account_tenant(self, mocker: MockerFixture) -> None:
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = "tenant-123"
|
||||
account._current_tenant = tenant
|
||||
mocker.patch.object(login_module, "current_account_with_tenant", return_value=(account, tenant.id))
|
||||
|
||||
tenant_id = login_module.resolve_tenant_id_fallback()
|
||||
|
||||
assert tenant_id == "tenant-123"
|
||||
|
||||
@ -5,10 +5,6 @@ from services.rag_pipeline.pipeline_template.pipeline_template_type import Pipel
|
||||
|
||||
|
||||
def test_get_pipeline_templates(mocker) -> None:
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.pipeline_template.customized.customized_retrieval.current_account_with_tenant",
|
||||
return_value=("account-id", "tenant-id"),
|
||||
)
|
||||
customized_template = SimpleNamespace(
|
||||
id="tpl-1",
|
||||
name="Custom Template",
|
||||
@ -27,7 +23,7 @@ def test_get_pipeline_templates(mocker) -> None:
|
||||
)
|
||||
retrieval = CustomizedPipelineTemplateRetrieval()
|
||||
|
||||
result = retrieval.get_pipeline_templates("en-US")
|
||||
result = retrieval.get_pipeline_templates("en-US", "tenant-id")
|
||||
|
||||
assert retrieval.get_type() == PipelineTemplateType.CUSTOMIZED
|
||||
assert result == {
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from models import Account, Tenant
|
||||
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin
|
||||
from models.workflow import Workflow
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
@ -25,6 +30,89 @@ class MockRepo:
|
||||
pass
|
||||
|
||||
|
||||
def _make_account(account_id: str = "u1", tenant_id: str = "t1") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
def _make_pipeline(
|
||||
*,
|
||||
pipeline_id: str = "p1",
|
||||
tenant_id: str = "t1",
|
||||
workflow_id: str | None = None,
|
||||
is_published: bool = False,
|
||||
) -> Pipeline:
|
||||
pipeline = Pipeline(tenant_id=tenant_id, name="Test Pipeline", description="test")
|
||||
pipeline.id = pipeline_id
|
||||
pipeline.workflow_id = workflow_id
|
||||
pipeline.is_published = is_published
|
||||
return pipeline
|
||||
|
||||
|
||||
def _make_workflow(
|
||||
*,
|
||||
workflow_id: str = "wf-1",
|
||||
tenant_id: str = "t1",
|
||||
app_id: str = "p1",
|
||||
graph: dict[str, object] | None = None,
|
||||
features: dict[str, object] | None = None,
|
||||
created_by: str = "u1",
|
||||
) -> Workflow:
|
||||
workflow = Workflow(
|
||||
id=workflow_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
marked_name="",
|
||||
marked_comment="",
|
||||
graph=json.dumps(graph or {"nodes": []}),
|
||||
features=json.dumps(features or {}),
|
||||
created_by=created_by,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
updated_by=None,
|
||||
updated_at=datetime(2024, 1, 1),
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
def _make_dataset(*, dataset_id: str = "d1", pipeline_id: str = "p1", tenant_id: str = "t1") -> Dataset:
|
||||
dataset = Dataset(
|
||||
id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
name="Test Dataset",
|
||||
created_by="u1",
|
||||
)
|
||||
dataset.pipeline_id = pipeline_id
|
||||
return dataset
|
||||
|
||||
|
||||
def _make_customized_template() -> PipelineCustomizedTemplate:
|
||||
return PipelineCustomizedTemplate(
|
||||
tenant_id="t1",
|
||||
name="old",
|
||||
description="old",
|
||||
chunk_structure="paragraph",
|
||||
icon={},
|
||||
position=1,
|
||||
yaml_content="",
|
||||
install_count=0,
|
||||
language="en-US",
|
||||
created_by="u1",
|
||||
)
|
||||
|
||||
|
||||
def _make_recommended_plugin(plugin_id: str) -> PipelineRecommendedPlugin:
|
||||
return PipelineRecommendedPlugin(plugin_id=plugin_id, provider_name=plugin_id, type="tool", position=0, active=True)
|
||||
|
||||
|
||||
def test_get_pipeline_templates_fallbacks_to_builtin_for_non_english_empty_result(mocker) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE", "remote")
|
||||
|
||||
@ -74,7 +162,7 @@ def test_get_pipeline_template_detail_uses_expected_mode(mocker, template_type:
|
||||
|
||||
|
||||
def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(workflow_id=None)
|
||||
pipeline = _make_pipeline(workflow_id=None)
|
||||
|
||||
result = rag_pipeline_service.get_published_workflow(pipeline)
|
||||
|
||||
@ -82,7 +170,7 @@ def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(ra
|
||||
|
||||
|
||||
def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(workflow_id=None)
|
||||
pipeline = _make_pipeline(workflow_id=None)
|
||||
session = SimpleNamespace()
|
||||
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
@ -101,7 +189,7 @@ def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_p
|
||||
def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_service) -> None:
|
||||
scalars_result = SimpleNamespace(all=lambda: ["wf1", "wf2", "wf3"])
|
||||
session = SimpleNamespace(scalars=lambda stmt: scalars_result)
|
||||
pipeline = SimpleNamespace(id="pipeline-1", workflow_id="wf-live")
|
||||
pipeline = _make_pipeline(pipeline_id="pipeline-1", workflow_id="wf-live")
|
||||
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
session=session,
|
||||
@ -133,8 +221,8 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.flush")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1", workflow_id=None)
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline(workflow_id=None)
|
||||
account = _make_account()
|
||||
|
||||
result = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
@ -153,11 +241,11 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s
|
||||
def test_sync_draft_workflow_raises_on_hash_mismatch(mocker, rag_pipeline_service) -> None:
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
|
||||
existing_wf = SimpleNamespace(unique_hash="hash-old")
|
||||
existing_wf = _make_workflow(graph={"nodes": [{"id": "old"}]})
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline()
|
||||
account = _make_account()
|
||||
|
||||
with pytest.raises(WorkflowHashNotEqualError):
|
||||
rag_pipeline_service.sync_draft_workflow(
|
||||
@ -184,8 +272,8 @@ def test_sync_draft_workflow_updates_existing(mocker, rag_pipeline_service) -> N
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline()
|
||||
account = _make_account()
|
||||
|
||||
result = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
@ -275,7 +363,7 @@ def test_get_rag_pipeline_paginate_workflow_runs_delegates(mocker, rag_pipeline_
|
||||
repo_mock.get_paginated_workflow_runs.return_value = expected
|
||||
rag_pipeline_service._workflow_run_repo = repo_mock
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
pipeline = _make_pipeline()
|
||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 10, "last_id": "abc"})
|
||||
|
||||
assert result is expected
|
||||
@ -297,7 +385,7 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
|
||||
repo_mock.get_workflow_run_by_id.return_value = expected
|
||||
rag_pipeline_service._workflow_run_repo = repo_mock
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
pipeline = _make_pipeline()
|
||||
result = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline, "run-1")
|
||||
|
||||
assert result is expected
|
||||
@ -310,14 +398,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
|
||||
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
pipeline = _make_pipeline()
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
|
||||
|
||||
|
||||
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
|
||||
|
||||
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
|
||||
pipeline = _make_pipeline()
|
||||
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
|
||||
|
||||
|
||||
@ -635,17 +723,10 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Pipeline
|
||||
|
||||
# 1. Setup mocks
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p1"
|
||||
pipeline.tenant_id = "t1"
|
||||
pipeline.workflow_id = "wf-1"
|
||||
pipeline.is_published = True
|
||||
pipeline = _make_pipeline(workflow_id="wf-1", is_published=True)
|
||||
|
||||
workflow = mocker.Mock()
|
||||
workflow.id = "wf-1"
|
||||
workflow = _make_workflow(workflow_id="wf-1")
|
||||
|
||||
# Mock db itself to avoid app context errors
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
@ -656,8 +737,8 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
|
||||
mock_db.session.scalar.side_effect = [None, 5]
|
||||
|
||||
# Mock retrieve_dataset
|
||||
dataset = mocker.Mock()
|
||||
pipeline.retrieve_dataset.return_value = dataset
|
||||
dataset = _make_dataset()
|
||||
dataset.chunk_structure = "paragraph"
|
||||
|
||||
# Mock RagPipelineDslService
|
||||
mock_dsl_service = mocker.Mock()
|
||||
@ -665,16 +746,14 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.RagPipelineDslService", return_value=mock_dsl_service)
|
||||
|
||||
# Mock Session and commit
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=mocker.MagicMock())
|
||||
session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker")
|
||||
session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = dataset
|
||||
|
||||
# Mock current_user
|
||||
mock_user = mocker.Mock()
|
||||
mock_user.id = "user-123"
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", mock_user)
|
||||
account = _make_account(account_id="user-123")
|
||||
|
||||
# 2. Run test
|
||||
args = {"name": "New Template", "description": "Desc", "icon_info": {"icon": "star"}, "tags": ["tag1"]}
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", args)
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", args, account, "t1")
|
||||
|
||||
# 3. Assertions
|
||||
# Verify a new template was added to session or similar?
|
||||
@ -687,14 +766,10 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
|
||||
|
||||
|
||||
def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Dataset, Pipeline
|
||||
|
||||
# 1. Setup mocks
|
||||
dataset = mocker.Mock(spec=Dataset)
|
||||
dataset.pipeline_id = "p1"
|
||||
dataset = _make_dataset()
|
||||
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p1"
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
workflow = mocker.Mock()
|
||||
workflow.graph_dict = {
|
||||
@ -835,15 +910,10 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
|
||||
|
||||
|
||||
def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
|
||||
# 1. Setup mocks
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p1"
|
||||
pipeline.tenant_id = "t1"
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
workflow = _make_workflow()
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
@ -856,16 +926,10 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
|
||||
|
||||
def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
|
||||
# 1. Setup mocks
|
||||
pipeline = mocker.Mock(spec=Pipeline)
|
||||
pipeline.id = "p1"
|
||||
pipeline.tenant_id = "t1"
|
||||
pipeline.workflow_id = "wf-pub"
|
||||
pipeline = _make_pipeline(workflow_id="wf-pub")
|
||||
|
||||
workflow = mocker.Mock(spec=Workflow)
|
||||
workflow = _make_workflow(workflow_id="wf-pub")
|
||||
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.scalar.return_value = workflow
|
||||
@ -896,8 +960,8 @@ def test_get_default_block_config_success(rag_pipeline_service) -> None:
|
||||
def test_publish_workflow_raises_when_draft_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
session = mocker.Mock()
|
||||
session.scalar.return_value = None
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline()
|
||||
account = _make_account()
|
||||
|
||||
with pytest.raises(ValueError, match="No valid workflow found"):
|
||||
rag_pipeline_service.publish_workflow(session=session, pipeline=pipeline, account=account)
|
||||
@ -929,8 +993,8 @@ def test_get_default_block_config_injects_http_request_filter(mocker, rag_pipeli
|
||||
|
||||
|
||||
def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline()
|
||||
account = _make_account()
|
||||
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not initialized"):
|
||||
@ -939,8 +1003,8 @@ def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeli
|
||||
|
||||
def test_run_draft_workflow_node_saves_execution_and_variables(mocker, rag_pipeline_service) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock()))
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
account = SimpleNamespace(id="u1")
|
||||
pipeline = _make_pipeline()
|
||||
account = _make_account()
|
||||
draft_workflow = mocker.Mock(id="wf-1")
|
||||
draft_workflow.get_node_config_by_id.return_value = {"id": "node-1"}
|
||||
draft_workflow.get_enclosing_node_type_and_id.return_value = ("loop", "enclosing-node")
|
||||
@ -1163,11 +1227,11 @@ def test_get_second_step_parameters_handles_string_and_list_variable_references(
|
||||
|
||||
|
||||
def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
pipeline = _make_pipeline()
|
||||
mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=None)
|
||||
|
||||
result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||
pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1")
|
||||
pipeline=pipeline, run_id="run-1", user=_make_account()
|
||||
)
|
||||
|
||||
assert result == []
|
||||
@ -1175,14 +1239,14 @@ def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mo
|
||||
|
||||
def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions(mocker, rag_pipeline_service) -> None:
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock()))
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
pipeline = _make_pipeline()
|
||||
mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=SimpleNamespace(id="run-1"))
|
||||
repo = mocker.Mock()
|
||||
repo.get_db_models_by_workflow_run.return_value = ["n1", "n2"]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.SQLAlchemyWorkflowNodeExecutionRepository", return_value=repo)
|
||||
|
||||
result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||
pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1")
|
||||
pipeline=pipeline, run_id="run-1", user=_make_account()
|
||||
)
|
||||
|
||||
assert result == ["n1", "n2"]
|
||||
@ -1192,7 +1256,7 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = rag_pipeline_service.get_recommended_plugins("all")
|
||||
result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1")
|
||||
|
||||
assert result == {
|
||||
"installed_recommended_plugins": [],
|
||||
@ -1201,11 +1265,10 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
|
||||
|
||||
|
||||
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
|
||||
plugin_a = SimpleNamespace(plugin_id="plugin-a")
|
||||
plugin_b = SimpleNamespace(plugin_id="plugin-b")
|
||||
plugin_a = _make_recommended_plugin("plugin-a")
|
||||
plugin_b = _make_recommended_plugin("plugin-b")
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch(
|
||||
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[SimpleNamespace(plugin_id="plugin-a", to_dict=lambda: {"plugin_id": "plugin-a"})],
|
||||
@ -1215,7 +1278,7 @@ def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_p
|
||||
return_value=[{"plugin_id": "plugin-b", "name": "Plugin B"}],
|
||||
)
|
||||
|
||||
result = rag_pipeline_service.get_recommended_plugins("custom")
|
||||
result = rag_pipeline_service.get_recommended_plugins("custom", _make_account(), "t1")
|
||||
|
||||
assert result["installed_recommended_plugins"] == [{"plugin_id": "plugin-a"}]
|
||||
assert result["uninstalled_recommended_plugins"] == [{"plugin_id": "plugin-b", "name": "Plugin B"}]
|
||||
@ -1229,8 +1292,8 @@ def test_get_node_last_run_delegates_to_repository(mocker, rag_pipeline_service)
|
||||
"services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository",
|
||||
return_value=repo,
|
||||
)
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
workflow = SimpleNamespace(id="wf1")
|
||||
pipeline = _make_pipeline()
|
||||
workflow = _make_workflow(workflow_id="wf1")
|
||||
|
||||
result = rag_pipeline_service.get_node_last_run(pipeline, workflow, "node-1")
|
||||
|
||||
@ -1572,15 +1635,17 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
|
||||
pipeline = _make_pipeline(workflow_id=None)
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
|
||||
|
||||
with pytest.raises(ValueError, match="Pipeline workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
|
||||
rag_pipeline_service.publish_customized_pipeline_template(
|
||||
"p1", {"name": "template-name"}, _make_account(), "t1"
|
||||
)
|
||||
|
||||
|
||||
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
@ -1630,13 +1695,12 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
|
||||
|
||||
|
||||
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
|
||||
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
|
||||
template = _make_customized_template()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
|
||||
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
|
||||
info = PipelineTemplateInfoEntity(name="", description="updated", icon_info=IconInfo(icon="i"))
|
||||
result = RagPipelineService.update_customized_pipeline_template("tpl-1", info)
|
||||
result = RagPipelineService.update_customized_pipeline_template("tpl-1", info, _make_account(), "t1")
|
||||
|
||||
assert result.description == "updated"
|
||||
commit.assert_called_once()
|
||||
@ -1644,7 +1708,7 @@ def test_update_customized_pipeline_template_commits_when_name_empty(mocker) ->
|
||||
|
||||
def test_get_all_published_workflow_without_filters_has_no_more(rag_pipeline_service) -> None:
|
||||
session = SimpleNamespace(scalars=lambda stmt: SimpleNamespace(all=lambda: ["wf1"]))
|
||||
pipeline = SimpleNamespace(id="p1", workflow_id="wf-live")
|
||||
pipeline = _make_pipeline(workflow_id="wf-live")
|
||||
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
session=session,
|
||||
@ -1856,38 +1920,34 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
pipeline = _make_pipeline(workflow_id="wf-1")
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
|
||||
|
||||
|
||||
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
|
||||
workflow = SimpleNamespace(id="wf-1")
|
||||
pipeline = _make_pipeline(workflow_id="wf-1")
|
||||
workflow = _make_workflow(workflow_id="wf-1")
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.engine = mocker.Mock()
|
||||
mock_db.session.get.side_effect = [pipeline, workflow]
|
||||
session_ctx = mocker.MagicMock()
|
||||
session_ctx.__enter__.return_value = SimpleNamespace()
|
||||
session_ctx.__exit__.return_value = False
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=session_ctx)
|
||||
pipeline.retrieve_dataset = lambda session: None
|
||||
session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker")
|
||||
session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
|
||||
rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1")
|
||||
|
||||
|
||||
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
|
||||
plugin = SimpleNamespace(plugin_id="plugin-a")
|
||||
plugin = _make_recommended_plugin("plugin-a")
|
||||
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
|
||||
mock_db.session.scalars.return_value.all.return_value = [plugin]
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
|
||||
|
||||
result = rag_pipeline_service.get_recommended_plugins("all")
|
||||
result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1")
|
||||
|
||||
assert result["installed_recommended_plugins"] == []
|
||||
assert result["uninstalled_recommended_plugins"] == []
|
||||
@ -1918,8 +1978,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
|
||||
|
||||
|
||||
def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
dataset = _make_dataset()
|
||||
pipeline = _make_pipeline()
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
|
||||
)
|
||||
@ -2079,8 +2139,8 @@ def test_set_datasource_variables_raises_when_workflow_missing(mocker, rag_pipel
|
||||
|
||||
|
||||
def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
dataset = _make_dataset()
|
||||
pipeline = _make_pipeline()
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]},
|
||||
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
|
||||
@ -2097,8 +2157,8 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
|
||||
|
||||
|
||||
def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
|
||||
dataset = _make_dataset()
|
||||
pipeline = _make_pipeline()
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
@ -2139,8 +2199,8 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
|
||||
|
||||
|
||||
def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None:
|
||||
dataset = SimpleNamespace(pipeline_id="p1")
|
||||
pipeline = SimpleNamespace(id="p1")
|
||||
dataset = _make_dataset()
|
||||
pipeline = _make_pipeline()
|
||||
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
|
||||
|
||||
result = rag_pipeline_service.get_pipeline("t1", "d1")
|
||||
|
||||
@ -1,65 +1,62 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account
|
||||
from models import Account, Tenant
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
class TestMetadataBugCompleteValidation:
|
||||
"""Complete test suite to verify the metadata nullable bug and its fix."""
|
||||
|
||||
def test_1_pydantic_layer_validation(self):
|
||||
def test_1_pydantic_layer_validation(self) -> None:
|
||||
"""Test Layer 1: Pydantic model validation correctly rejects None values."""
|
||||
# Pydantic should reject None values for required fields
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name=None)
|
||||
MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type]
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type="string", name=None)
|
||||
MetadataArgs(type="string", name=None) # pyrefly: ignore[bad-argument-type]
|
||||
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs(type=None, name="test")
|
||||
MetadataArgs(type=None, name="test") # pyrefly: ignore[bad-argument-type]
|
||||
|
||||
# Valid values should work
|
||||
valid_args = MetadataArgs(type="string", name="test_name")
|
||||
assert valid_args.type == "string"
|
||||
assert valid_args.name == "test_name"
|
||||
|
||||
def test_2_business_logic_layer_crashes_on_none(self):
|
||||
def test_2_business_logic_layer_crashes_on_none(self) -> None:
|
||||
"""Test Layer 2: Business logic crashes when None values slip through."""
|
||||
# Create mock that bypasses Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# Should crash with TypeError
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
account = _make_account()
|
||||
# Should crash with TypeError
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
|
||||
|
||||
# Test update method as well
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
account = _make_account()
|
||||
none_name = cast(str, None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123")
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_3_database_constraints_verification(self):
|
||||
def test_3_database_constraints_verification(self) -> None:
|
||||
"""Test Layer 3: Verify database model has nullable=False constraints."""
|
||||
from sqlalchemy import inspect
|
||||
|
||||
@ -75,7 +72,7 @@ class TestMetadataBugCompleteValidation:
|
||||
assert type_column.nullable is False, "type column should be nullable=False"
|
||||
assert name_column.nullable is False, "name column should be nullable=False"
|
||||
|
||||
def test_4_fixed_api_layer_rejects_null(self):
|
||||
def test_4_fixed_api_layer_rejects_null(self) -> None:
|
||||
"""Test Layer 4: Fixed API configuration properly rejects null values using Pydantic."""
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate({"type": None, "name": None})
|
||||
@ -86,30 +83,23 @@ class TestMetadataBugCompleteValidation:
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
MetadataArgs.model_validate({"type": None, "name": "test"})
|
||||
|
||||
def test_5_fixed_api_accepts_valid_values(self):
|
||||
def test_5_fixed_api_accepts_valid_values(self) -> None:
|
||||
"""Test that fixed API still accepts valid non-null values."""
|
||||
args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"})
|
||||
assert args.type == "string"
|
||||
assert args.name == "valid_name"
|
||||
|
||||
def test_6_simulated_buggy_behavior(self):
|
||||
def test_6_simulated_buggy_behavior(self) -> None:
|
||||
"""Test simulating the original buggy behavior by bypassing Pydantic validation."""
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None
|
||||
mock_metadata_args.type = None
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
account = _make_account()
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_7_end_to_end_validation_layers(self):
|
||||
def test_7_end_to_end_validation_layers(self) -> None:
|
||||
"""Test all validation layers work together correctly."""
|
||||
# Layer 1: API should reject null at parameter level (with fix)
|
||||
# Layer 2: Pydantic should reject null at model level
|
||||
@ -128,7 +118,7 @@ class TestMetadataBugCompleteValidation:
|
||||
assert len(metadata_args.name) <= 255 # This should not crash
|
||||
assert len(metadata_args.type) > 0 # This should not crash
|
||||
|
||||
def test_8_verify_specific_fix_locations(self):
|
||||
def test_8_verify_specific_fix_locations(self) -> None:
|
||||
"""Verify that the specific locations mentioned in bug report are fixed."""
|
||||
# Read the actual files to verify fixes
|
||||
import os
|
||||
@ -152,7 +142,7 @@ class TestMetadataBugCompleteValidation:
|
||||
class TestMetadataValidationSummary:
|
||||
"""Summary tests that demonstrate the complete validation architecture."""
|
||||
|
||||
def test_validation_layer_architecture(self):
|
||||
def test_validation_layer_architecture(self) -> None:
|
||||
"""Document and test the 4-layer validation architecture."""
|
||||
# Layer 1: API Parameter Validation (Flask-RESTful reqparse)
|
||||
# - Role: First line of defense, validates HTTP request parameters
|
||||
|
||||
@ -1,56 +1,53 @@
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account
|
||||
from models import Account, Tenant
|
||||
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.id = tenant_id
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
class TestMetadataNullableBug:
|
||||
"""Test case to reproduce the metadata nullable validation bug."""
|
||||
|
||||
def test_metadata_args_with_none_values_should_fail(self):
|
||||
def test_metadata_args_with_none_values_should_fail(self) -> None:
|
||||
"""Test that MetadataArgs validation should reject None values."""
|
||||
# This test demonstrates the expected behavior - should fail validation
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
# This should fail because Pydantic expects non-None values
|
||||
MetadataArgs(type=None, name=None)
|
||||
MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type]
|
||||
|
||||
def test_metadata_service_create_with_none_name_crashes(self):
|
||||
def test_metadata_service_create_with_none_name_crashes(self) -> None:
|
||||
"""Test that MetadataService.create_metadata crashes when name is None."""
|
||||
# Mock the MetadataArgs to bypass Pydantic validation
|
||||
mock_metadata_args = Mock()
|
||||
mock_metadata_args.name = None # This will cause len() to crash
|
||||
mock_metadata_args.type = "string"
|
||||
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
account = _make_account()
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123")
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.create_metadata("dataset-123", mock_metadata_args)
|
||||
|
||||
def test_metadata_service_update_with_none_name_crashes(self):
|
||||
def test_metadata_service_update_with_none_name_crashes(self) -> None:
|
||||
"""Test that MetadataService.update_metadata_name crashes when name is None."""
|
||||
mock_user = create_autospec(Account, instance=True)
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
mock_user.id = "user-456"
|
||||
account = _make_account()
|
||||
none_name = cast(str, None)
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123")
|
||||
|
||||
with patch(
|
||||
"services.metadata_service.current_account_with_tenant",
|
||||
return_value=(mock_user, mock_user.current_tenant_id),
|
||||
):
|
||||
# This should crash with TypeError when calling len(None)
|
||||
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
|
||||
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
|
||||
|
||||
def test_api_layer_now_uses_pydantic_validation(self):
|
||||
def test_api_layer_now_uses_pydantic_validation(self) -> None:
|
||||
"""Verify that API layer relies on Pydantic validation instead of reqparse."""
|
||||
invalid_payload = {"type": None, "name": None}
|
||||
with pytest.raises((ValueError, TypeError)):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user