diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py
index 67d8b598b0..4f7f7d9a98 100644
--- a/api/controllers/service_api/__init__.py
+++ b/api/controllers/service_api/__init__.py
@@ -34,6 +34,7 @@ from .dataset import (
     metadata,
     segment,
 )
+from .dataset.rag_pipeline import rag_pipeline_workflow
 from .end_user import end_user
 from .workspace import models
 
@@ -53,6 +54,7 @@ __all__ = [
     "message",
     "metadata",
     "models",
+    "rag_pipeline_workflow",
     "segment",
     "site",
     "workflow",
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 70b5030237..94cbee1f58 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -1,5 +1,3 @@
-import string
-import uuid
 from collections.abc import Generator
 from typing import Any
 
@@ -41,7 +39,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
 register_schema_model(service_api_ns, PipelineRunApiEntity)
 
 
-@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
 class DatasourcePluginsApi(DatasetApiResource):
     """Resource for datasource plugins."""
 
@@ -76,7 +74,7 @@
         return datasource_plugins, 200
 
 
-@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
 class DatasourceNodeRunApi(DatasetApiResource):
     """Resource for datasource node run."""
 
@@ -131,7 +129,7 @@
         )
 
 
-@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
+@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
 class PipelineRunApi(DatasetApiResource):
     """Resource for datasource node run."""
 
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index b80735914d..cc55c69c48 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -217,6 +217,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
     def decorator(view: Callable[Concatenate[T, P], R]):
         @wraps(view)
         def decorated(*args: P.args, **kwargs: P.kwargs):
+            api_token = validate_and_get_api_token("dataset")
+
             # get url path dataset_id from positional args or kwargs
             # Flask passes URL path parameters as positional arguments
             dataset_id = None
@@ -253,12 +255,18 @@
             # Validate dataset if dataset_id is provided
             if dataset_id:
                 dataset_id = str(dataset_id)
-                dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+                dataset = (
+                    db.session.query(Dataset)
+                    .where(
+                        Dataset.id == dataset_id,
+                        Dataset.tenant_id == api_token.tenant_id,
+                    )
+                    .first()
+                )
                 if not dataset:
                     raise NotFound("Dataset not found.")
                 if not dataset.enable_api:
                     raise Forbidden("Dataset api access is not enabled.")
-            api_token = validate_and_get_api_token("dataset")
             tenant_account_join = (
                 db.session.query(Tenant, TenantAccountJoin)
                 .where(Tenant.id == api_token.tenant_id)
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index ccc6abcc06..4e33b312f4 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -1329,10 +1329,24 @@ class RagPipelineService:
         """
         Get datasource plugins
         """
-        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = (
+            db.session.query(Dataset)
+            .where(
+                Dataset.id == dataset_id,
+                Dataset.tenant_id == tenant_id,
+            )
+            .first()
+        )
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = (
+            db.session.query(Pipeline)
+            .where(
+                Pipeline.id == dataset.pipeline_id,
+                Pipeline.tenant_id == tenant_id,
+            )
+            .first()
+        )
         if not pipeline:
             raise ValueError("Pipeline not found")
 
@@ -1413,10 +1427,24 @@ class RagPipelineService:
         """
         Get pipeline
         """
-        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = (
+            db.session.query(Dataset)
+            .where(
+                Dataset.id == dataset_id,
+                Dataset.tenant_id == tenant_id,
+            )
+            .first()
+        )
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = (
+            db.session.query(Pipeline)
+            .where(
+                Pipeline.id == dataset.pipeline_id,
+                Pipeline.tenant_id == tenant_id,
+            )
+            .first()
+        )
         if not pipeline:
             raise ValueError("Pipeline not found")
         return pipeline
diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_rag_pipeline_route_registration.py b/api/tests/unit_tests/controllers/service_api/dataset/test_rag_pipeline_route_registration.py
new file mode 100644
index 0000000000..184e37014b
--- /dev/null
+++ b/api/tests/unit_tests/controllers/service_api/dataset/test_rag_pipeline_route_registration.py
@@ -0,0 +1,54 @@
+"""
+Unit tests for Service API knowledge pipeline route registration.
+"""
+
+import ast
+from pathlib import Path
+
+
+def test_rag_pipeline_routes_registered():
+    api_dir = Path(__file__).resolve().parents[5]
+
+    service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
+    rag_pipeline_workflow = (
+        api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
+    )
+
+    assert service_api_init.exists()
+    assert rag_pipeline_workflow.exists()
+
+    init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
+    import_found = False
+    for node in ast.walk(init_tree):
+        if not isinstance(node, ast.ImportFrom):
+            continue
+        if node.module != "dataset.rag_pipeline" or node.level != 1:
+            continue
+        if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
+            import_found = True
+            break
+    assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
+
+    workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
+    route_paths: set[str] = set()
+
+    for node in ast.walk(workflow_tree):
+        if not isinstance(node, ast.ClassDef):
+            continue
+        for decorator in node.decorator_list:
+            if not isinstance(decorator, ast.Call):
+                continue
+            if not isinstance(decorator.func, ast.Attribute):
+                continue
+            if decorator.func.attr != "route":
+                continue
+            if not decorator.args:
+                continue
+            first_arg = decorator.args[0]
+            if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
+                route_paths.add(first_arg.value)
+
+    assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
+    assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
+    assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
+    assert "/datasets/pipeline/file-upload" in route_paths