From 22e67b46730e9053240ee922c15e3edf7227b923 Mon Sep 17 00:00:00 2001 From: Eric Cao Date: Tue, 9 Jun 2026 11:14:50 +0800 Subject: [PATCH] chore(api): convert PipelineTemplateRetrievalBase from ABC to Protocol (#37201) --- .../pipeline_template_base.py | 17 ++++---------- .../test_pipeline_template_base.py | 23 ------------------- 2 files changed, 5 insertions(+), 35 deletions(-) diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 0ed2a4b8f2..9cfb8f36aa 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,18 +1,11 @@ -from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Protocol -class PipelineTemplateRetrievalBase(ABC): +class PipelineTemplateRetrievalBase(Protocol): """Interface for pipeline template retrieval.""" - @abstractmethod - def get_pipeline_templates(self, language: str) -> dict[str, Any]: - raise NotImplementedError + def get_pipeline_templates(self, language: str) -> dict[str, Any]: ... - @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - raise NotImplementedError + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ... - @abstractmethod - def get_type(self) -> str: - raise NotImplementedError + def get_type(self) -> str: ... diff --git a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py index 304ee8faa3..5918d74f89 100644 --- a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py +++ b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_pipeline_template_base.py @@ -1,5 +1,3 @@ -import pytest - from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase @@ -14,30 +12,9 @@ class DummyRetrieval(PipelineTemplateRetrievalBase): return "dummy" -class MissingTypeRetrieval(PipelineTemplateRetrievalBase): - def get_pipeline_templates(self, language: str) -> dict: - return {"language": language} - - def get_pipeline_template_detail(self, template_id: str) -> dict | None: - return {"id": template_id} - - def test_pipeline_template_retrieval_base_concrete_implementation() -> None: retrieval = DummyRetrieval() assert retrieval.get_pipeline_templates("en-US") == {"language": "en-US"} assert retrieval.get_pipeline_template_detail("tpl-1") == {"id": "tpl-1"} assert retrieval.get_type() == "dummy" - - -def test_pipeline_template_retrieval_base_requires_abstract_methods() -> None: - assert "get_type" in MissingTypeRetrieval.__abstractmethods__ - - -def test_pipeline_template_retrieval_base_default_methods_raise() -> None: - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_pipeline_templates(DummyRetrieval(), "en-US") - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_pipeline_template_detail(DummyRetrieval(), "tpl-1") - with pytest.raises(NotImplementedError): - PipelineTemplateRetrievalBase.get_type(DummyRetrieval())