mirror of
https://github.com/langgenius/dify.git
synced 2026-06-13 20:24:18 +08:00
chore(api): Fix several typing errors (#37248)
This commit is contained in:
parent
ad96501e09
commit
7cf75c3cc5
@ -11,8 +11,10 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
def __init__(self, entity: ToolProviderEntity):
|
||||
class ToolProviderController[ToolProviderEntityT: ToolProviderEntity, ToolProviderToolT: Tool | None](ABC):
|
||||
entity: ToolProviderEntityT
|
||||
|
||||
def __init__(self, entity: ToolProviderEntityT):
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
@ -24,7 +26,7 @@ class ToolProviderController(ABC):
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str) -> ToolProviderToolT:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from core.tools.errors import (
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
class BuiltinToolProviderController(ToolProviderController[ToolProviderEntity, BuiltinTool | None]):
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
@ -163,7 +163,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
|
||||
@ -24,7 +24,7 @@ from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
class ApiToolProviderController(ToolProviderController[ToolProviderEntity, ApiTool]):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
@ -18,7 +18,7 @@ from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class MCPToolProviderController(ToolProviderController):
|
||||
class MCPToolProviderController(ToolProviderController[ToolProviderEntityWithPlugin, MCPTool]):
|
||||
def __init__(
|
||||
self,
|
||||
entity: ToolProviderEntityWithPlugin,
|
||||
|
||||
@ -9,7 +9,9 @@ from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
class PluginToolProviderController(BuiltinToolProviderController):
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
# TODO: Split the credential/schema helpers from BuiltinToolProviderController
|
||||
# so plugin providers do not need to inherit builtin tool-loading behavior.
|
||||
entity: ToolProviderEntityWithPlugin # pyrefly: ignore[bad-override-mutable-attribute]
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
@ -46,7 +48,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore[override] # pyrefly: ignore[bad-override]
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
@ -65,7 +68,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore
|
||||
@override
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore[override] # pyrefly: ignore[bad-override]
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
|
||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -43,13 +42,14 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
class WorkflowToolProviderController(ToolProviderController[ToolProviderEntity, WorkflowTool | None]):
|
||||
provider_id: str
|
||||
tools: list[WorkflowTool] = Field(default_factory=list)
|
||||
tools: list[WorkflowTool] | None
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str):
|
||||
super().__init__(entity=entity)
|
||||
self.provider_id = provider_id
|
||||
self.tools = None
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||
@ -241,7 +241,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> WorkflowTool | None: # type: ignore
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> WorkflowTool | None:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
|
||||
@ -27,12 +27,14 @@ class HuaweiObsStorage(BaseStorage):
|
||||
|
||||
@override
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read()
|
||||
# TODO: Huawei SDK lacks proper typing
|
||||
data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response.read() # type: ignore
|
||||
return data
|
||||
|
||||
@override
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response
|
||||
# TODO: Huawei SDK lacks proper typing
|
||||
response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename).body.response # type: ignore
|
||||
while chunk := response.read(4096):
|
||||
yield chunk
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
# ===================================================================
|
||||
|
||||
from hashlib import sha1
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import Crypto.Hash.SHA1
|
||||
import Crypto.Util.number
|
||||
@ -30,6 +31,9 @@ from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
||||
from Crypto.Util.py3compat import bord
|
||||
from Crypto.Util.strxor import strxor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from Crypto.Signature.pss import HashModule
|
||||
|
||||
|
||||
class PKCS1OAepCipher:
|
||||
"""Cipher object for PKCS#1 v1.5 OAEP.
|
||||
@ -70,7 +74,7 @@ class PKCS1OAepCipher:
|
||||
if mgfunc:
|
||||
self._mgf = mgfunc
|
||||
else:
|
||||
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
|
||||
self._mgf = lambda x, y: MGF1(x, y, cast("HashModule", self._hashObj))
|
||||
|
||||
self._label = bytes(label)
|
||||
self._randfunc = randfunc
|
||||
|
||||
@ -53,11 +53,3 @@ providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py
|
||||
providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py
|
||||
providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
|
||||
providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py
|
||||
core/tools/mcp_tool/provider.py
|
||||
core/tools/plugin_tool/provider.py
|
||||
core/tools/workflow_as_tool/provider.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
services/audio_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
|
||||
|
||||
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AudioService:
|
||||
@classmethod
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
|
||||
def transcript_asr(cls, app_model: App, file: FileStorage | None, end_user: str | None = None):
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
@ -141,14 +141,14 @@ class AudioService:
|
||||
else:
|
||||
response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore
|
||||
return response
|
||||
else:
|
||||
if text is None:
|
||||
raise ValueError("Text is required")
|
||||
response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg") # type: ignore
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import ClassVar
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
|
||||
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
|
||||
@ -8,5 +9,5 @@ class DocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
|
||||
"""Proxy for document indexing tasks."""
|
||||
|
||||
QUEUE_NAME: ClassVar[str] = "document_indexing"
|
||||
NORMAL_TASK_FUNC = normal_document_indexing_task # pyrefly: ignore[missing-override-decorator]
|
||||
PRIORITY_TASK_FUNC = priority_document_indexing_task # pyrefly: ignore[missing-override-decorator]
|
||||
NORMAL_TASK_FUNC: ClassVar[Callable[..., Any]] = normal_document_indexing_task # pyrefly: ignore
|
||||
PRIORITY_TASK_FUNC: ClassVar[Callable[..., Any]] = priority_document_indexing_task # pyrefly: ignore
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import ClassVar
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
|
||||
from tasks.duplicate_document_indexing_task import (
|
||||
@ -6,10 +7,12 @@ from tasks.duplicate_document_indexing_task import (
|
||||
priority_duplicate_document_indexing_task,
|
||||
)
|
||||
|
||||
TaskFunc = Callable[..., Any]
|
||||
|
||||
|
||||
class DuplicateDocumentIndexingTaskProxy(BatchDocumentIndexingProxy):
|
||||
"""Proxy for duplicate document indexing tasks."""
|
||||
|
||||
QUEUE_NAME: ClassVar[str] = "duplicate_document_indexing"
|
||||
NORMAL_TASK_FUNC = normal_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator]
|
||||
PRIORITY_TASK_FUNC = priority_duplicate_document_indexing_task # pyrefly: ignore[missing-override-decorator]
|
||||
NORMAL_TASK_FUNC: ClassVar[TaskFunc] = normal_duplicate_document_indexing_task # pyrefly: ignore
|
||||
PRIORITY_TASK_FUNC: ClassVar[TaskFunc] = priority_duplicate_document_indexing_task # pyrefly: ignore
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -16,6 +16,7 @@ from core.tools.errors import ToolProviderNotFoundError
|
||||
|
||||
|
||||
class _FakeBuiltinTool(BuiltinTool):
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -30,6 +31,7 @@ class _FakeBuiltinTool(BuiltinTool):
|
||||
class _ConcreteBuiltinProvider(BuiltinToolProviderController):
|
||||
last_validation: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
self.last_validation = (user_id, credentials)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -14,6 +14,7 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create a mock class for testing abstract/base classes
|
||||
class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController):
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
return None
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
import pytest
|
||||
|
||||
@ -22,9 +22,11 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class _DummyTool(Tool):
|
||||
@override
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -36,7 +38,8 @@ class _DummyTool(Tool):
|
||||
yield self.create_text_message("ok")
|
||||
|
||||
|
||||
class _DummyController(ToolProviderController):
|
||||
class _DummyController(ToolProviderController[ToolProviderEntity, Tool]):
|
||||
@override
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
|
||||
@ -1,19 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
|
||||
def _controller() -> WorkflowToolProviderController:
|
||||
@ -30,6 +41,64 @@ def _controller() -> WorkflowToolProviderController:
|
||||
return WorkflowToolProviderController(entity=entity, provider_id="provider-1")
|
||||
|
||||
|
||||
def _app() -> App:
|
||||
return App(id="app-1")
|
||||
|
||||
|
||||
def _account() -> Account:
|
||||
return Account(name="Alice", email="alice@example.com")
|
||||
|
||||
|
||||
def _workflow() -> Workflow:
|
||||
return Workflow.new(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
version="1",
|
||||
graph=json.dumps({"nodes": []}),
|
||||
features="{}",
|
||||
created_by="user-1",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _db_provider(*, parameter_configuration: str = "[]") -> WorkflowToolProvider:
|
||||
return WorkflowToolProvider(
|
||||
name="workflow_tool",
|
||||
label="WF Provider",
|
||||
icon="icon.svg",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
description="desc",
|
||||
parameter_configuration=parameter_configuration,
|
||||
)
|
||||
|
||||
|
||||
def _workflow_tool(name: str = "workflow_tool") -> WorkflowTool:
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id="provider-1",
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author="author",
|
||||
name=name,
|
||||
label=I18nObject(en_US=name),
|
||||
provider="provider-1",
|
||||
),
|
||||
description=ToolDescription(human=I18nObject(en_US="desc"), llm="desc"),
|
||||
parameters=[],
|
||||
),
|
||||
runtime=ToolRuntime(tenant_id="tenant-1"),
|
||||
workflow_app_id="app-1",
|
||||
workflow_entities={"app": _app(), "workflow": _workflow()},
|
||||
version="1",
|
||||
workflow_call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
def _mock_session_with_begin() -> Mock:
|
||||
session = Mock()
|
||||
begin_cm = Mock()
|
||||
@ -42,25 +111,18 @@ def _mock_session_with_begin() -> Mock:
|
||||
def test_get_db_provider_tool_builds_entity():
|
||||
controller = _controller()
|
||||
session = Mock()
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={})
|
||||
workflow = _workflow()
|
||||
session.scalar.return_value = workflow
|
||||
app = SimpleNamespace(id="app-1")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
parameter_configurations=[
|
||||
SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM),
|
||||
SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM),
|
||||
],
|
||||
app = _app()
|
||||
db_provider = _db_provider(
|
||||
parameter_configuration=json.dumps(
|
||||
[
|
||||
{"name": "country", "description": "Country", "form": ToolParameter.ToolParameterForm.FORM.value},
|
||||
{"name": "files", "description": "files", "form": ToolParameter.ToolParameterForm.FORM.value},
|
||||
]
|
||||
)
|
||||
)
|
||||
user = SimpleNamespace(name="Alice")
|
||||
user = _account()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="country",
|
||||
@ -94,8 +156,9 @@ def test_get_db_provider_tool_builds_entity():
|
||||
|
||||
assert tool.entity.identity.name == "workflow_tool"
|
||||
# "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out.
|
||||
assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}}
|
||||
assert "json" not in tool.entity.output_schema["properties"]
|
||||
properties = cast(dict[str, Any], tool.entity.output_schema["properties"])
|
||||
assert properties == {"answer": {"type": "string", "description": ""}}
|
||||
assert "json" not in properties
|
||||
assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT
|
||||
assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
assert controller.provider_type == ToolProviderType.WORKFLOW
|
||||
@ -103,7 +166,7 @@ def test_get_db_provider_tool_builds_entity():
|
||||
|
||||
def test_get_tool_returns_hit_or_none():
|
||||
controller = _controller()
|
||||
tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool")))
|
||||
tool = _workflow_tool()
|
||||
controller.tools = [tool]
|
||||
|
||||
assert controller.get_tool("workflow_tool") is tool
|
||||
@ -112,29 +175,16 @@ def test_get_tool_returns_hit_or_none():
|
||||
|
||||
def test_get_tools_returns_cached():
|
||||
controller = _controller()
|
||||
cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))]
|
||||
controller.tools = cached_tools # type: ignore[assignment]
|
||||
cached_tools = [_workflow_tool("wf-cached")]
|
||||
controller.tools = cached_tools
|
||||
|
||||
assert controller.get_tools("tenant-1") == cached_tools
|
||||
|
||||
|
||||
def test_from_db_builds_controller():
|
||||
controller = _controller()
|
||||
|
||||
app = SimpleNamespace(id="app-1")
|
||||
user = SimpleNamespace(name="Alice")
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
user_id="user-1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
parameter_configurations=[],
|
||||
)
|
||||
app = _app()
|
||||
user = _account()
|
||||
db_provider = _db_provider()
|
||||
session = _mock_session_with_begin()
|
||||
session.scalar.return_value = db_provider
|
||||
session.get.side_effect = [app, user]
|
||||
@ -148,7 +198,7 @@ def test_from_db_builds_controller():
|
||||
with patch.object(
|
||||
WorkflowToolProviderController,
|
||||
"_get_db_provider_tool",
|
||||
return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))),
|
||||
return_value=_workflow_tool("wf"),
|
||||
):
|
||||
built = WorkflowToolProviderController.from_db(db_provider)
|
||||
assert isinstance(built, WorkflowToolProviderController)
|
||||
@ -157,7 +207,7 @@ def test_from_db_builds_controller():
|
||||
|
||||
def test_get_tools_returns_empty_when_provider_missing():
|
||||
controller = _controller()
|
||||
controller.tools = None # type: ignore[assignment]
|
||||
controller.tools = None
|
||||
|
||||
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
@ -171,19 +221,8 @@ def test_get_tools_returns_empty_when_provider_missing():
|
||||
|
||||
def test_get_tools_raises_when_app_missing():
|
||||
controller = _controller()
|
||||
controller.tools = None # type: ignore[assignment]
|
||||
db_provider = SimpleNamespace(
|
||||
id="provider-1",
|
||||
app_id="app-1",
|
||||
version="1",
|
||||
user_id="user-1",
|
||||
label="WF Provider",
|
||||
description="desc",
|
||||
icon="icon.svg",
|
||||
name="workflow_tool",
|
||||
tenant_id="tenant-1",
|
||||
parameter_configurations=[],
|
||||
)
|
||||
controller.tools = None
|
||||
db_provider = _db_provider()
|
||||
|
||||
with patch("core.tools.workflow_as_tool.provider.db") as mock_db:
|
||||
mock_db.engine = object()
|
||||
|
||||
@ -10,7 +10,7 @@ class DocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
|
||||
@ -12,7 +12,7 @@ class DuplicateDocumentIndexingTaskProxyTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for DuplicateDocumentIndexingTaskProxy tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan = CloudPlan.SANDBOX) -> Mock:
|
||||
def create_mock_features(billing_enabled: bool = False, plan: CloudPlan | str | None = CloudPlan.SANDBOX) -> Mock:
|
||||
"""Create mock features with billing configuration."""
|
||||
features = Mock()
|
||||
features.billing = Mock()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user