diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 0ed2a4b8f2..9cfb8f36aa 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,18 +1,11 @@ -from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Protocol -class PipelineTemplateRetrievalBase(ABC): +class PipelineTemplateRetrievalBase(Protocol): """Interface for pipeline template retrieval.""" - @abstractmethod - def get_pipeline_templates(self, language: str) -> dict[str, Any]: - raise NotImplementedError + def get_pipeline_templates(self, language: str) -> dict[str, Any]: ... - @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - raise NotImplementedError + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ... - @abstractmethod - def get_type(self) -> str: - raise NotImplementedError + def get_type(self) -> str: ... diff --git a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py index 304ee8faa3..5918d74f89 100644 --- a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py +++ b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py @@ -1,5 +1,3 @@ -import pytest - from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase @@ -14,30 +12,9 @@ class DummyRetrieval(PipelineTemplateRetrievalBase): return "dummy" -class MissingTypeRetrieval(PipelineTemplateRetrievalBase): - def get_pipeline_templates(self, language: str) -> dict: - return {"language": language} - - def get_pipeline_template_detail(self, template_id: str) -> dict | None: - return {"id": template_id} - - def test_pipeline_template_retrieval_base_concrete_implementation() -> None: retrieval = DummyRetrieval() assert retrieval.get_pipeline_templates("en-US") == {"language": "en-US"} assert retrieval.get_pipeline_template_detail("tpl-1") == {"id": "tpl-1"} assert retrieval.get_type() == "dummy" - - -def test_pipeline_template_retrieval_base_requires_abstract_methods() -> None: - assert "get_type" in MissingTypeRetrieval.__abstractmethods__ - - -def test_pipeline_template_retrieval_base_default_methods_raise() -> None: - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_pipeline_templates(DummyRetrieval(), "en-US") - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_pipeline_template_detail(DummyRetrieval(), "tpl-1") - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_type(DummyRetrieval())