chore(api): convert PipelineTemplateRetrievalBase from ABC to Protocol (#37201)

This commit is contained in:
Eric Cao 2026-06-09 11:14:50 +08:00 committed by GitHub
parent f948e442e0
commit 22e67b4673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 35 deletions

View File

@ -1,18 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Protocol
class PipelineTemplateRetrievalBase(ABC):
class PipelineTemplateRetrievalBase(Protocol):
"""Interface for pipeline template retrieval."""
@abstractmethod
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
raise NotImplementedError
def get_pipeline_templates(self, language: str) -> dict[str, Any]: ...
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
raise NotImplementedError
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ...
@abstractmethod
def get_type(self) -> str:
raise NotImplementedError
def get_type(self) -> str: ...

View File

@ -1,5 +1,3 @@
import pytest
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
@ -14,30 +12,9 @@ class DummyRetrieval(PipelineTemplateRetrievalBase):
return "dummy"
class MissingTypeRetrieval(PipelineTemplateRetrievalBase):
def get_pipeline_templates(self, language: str) -> dict:
return {"language": language}
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
return {"id": template_id}
def test_pipeline_template_retrieval_base_concrete_implementation() -> None:
retrieval = DummyRetrieval()
assert retrieval.get_pipeline_templates("en-US") == {"language": "en-US"}
assert retrieval.get_pipeline_template_detail("tpl-1") == {"id": "tpl-1"}
assert retrieval.get_type() == "dummy"
def test_pipeline_template_retrieval_base_requires_abstract_methods() -> None:
assert "get_type" in MissingTypeRetrieval.__abstractmethods__
def test_pipeline_template_retrieval_base_default_methods_raise() -> None:
with pytest.raises(NotImplementedError):
PipelineTemplateRetrievalBase.get_pipeline_templates(DummyRetrieval(), "en-US")
with pytest.raises(NotImplementedError):
PipelineTemplateRetrievalBase.get_pipeline_template_detail(DummyRetrieval(), "tpl-1")
with pytest.raises(NotImplementedError):
PipelineTemplateRetrievalBase.get_type(DummyRetrieval())