diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 63c839a4e4..0c710c8716 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -3,7 +3,7 @@ from typing import Generic, Optional, TypeVar from pydantic import BaseModel -T = TypeVar("T", bound=(BaseModel | dict | bool)) +T = TypeVar("T", bound=(BaseModel | dict | list | bool)) class PluginDaemonBasicResponse(BaseModel, Generic[T]): diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/manager/base.py index f6b44d05dd..3f6a87dca8 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/manager/base.py @@ -12,7 +12,7 @@ from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY -T = TypeVar("T", bound=(BaseModel | dict | bool)) +T = TypeVar("T", bound=(BaseModel | dict | list | bool)) class BasePluginManager: @@ -22,6 +22,7 @@ class BasePluginManager: path: str, headers: dict | None = None, data: bytes | dict | None = None, + params: dict | None = None, stream: bool = False, ) -> requests.Response: """ @@ -30,16 +31,23 @@ class BasePluginManager: url = URL(str(plugin_daemon_inner_api_baseurl)) / path headers = headers or {} headers["X-Api-Key"] = plugin_daemon_inner_api_key - response = requests.request(method=method, url=str(url), headers=headers, data=data, stream=stream) + response = requests.request( + method=method, url=str(url), headers=headers, data=data, params=params, stream=stream + ) return response def _stream_request( - self, method: str, path: str, headers: dict | None = None, data: bytes | dict | None = None + self, + method: str, + path: str, + params: dict | None = None, + headers: dict | None = None, + data: bytes | dict | None = None, ) -> Generator[bytes, None, None]: """ Make a stream request to the plugin daemon inner API """ - response = self._request(method, path, headers, data, stream=True) + response = self._request(method, path, headers, data, params, stream=True) yield from response.iter_lines() def _stream_request_with_model( @@ -49,29 +57,42 @@ class BasePluginManager: type: type[T], headers: dict | None = None, data: bytes | dict | None = None, + params: dict | None = None, ) -> Generator[T, None, None]: """ Make a stream request to the plugin daemon inner API and yield the response as a model. """ - for line in self._stream_request(method, path, headers, data): + for line in self._stream_request(method, path, params, headers, data): yield type(**json.loads(line)) def _request_with_model( - self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | None = None + self, + method: str, + path: str, + type: type[T], + headers: dict | None = None, + data: bytes | None = None, + params: dict | None = None, ) -> T: """ Make a request to the plugin daemon inner API and return the response as a model. """ - response = self._request(method, path, headers, data) + response = self._request(method, path, headers, data, params) return type(**response.json()) def _request_with_plugin_daemon_response( - self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None + self, + method: str, + path: str, + type: type[T], + headers: dict | None = None, + data: bytes | dict | None = None, + params: dict | None = None, ) -> T: """ Make a request to the plugin daemon inner API and return the response as a model. """ - response = self._request(method, path, headers, data) + response = self._request(method, path, headers, data, params) rep = PluginDaemonBasicResponse[type](**response.json()) if rep.code != 0: raise ValueError(f"got error from plugin daemon: {rep.message}, code: {rep.code}") @@ -81,12 +102,18 @@ class BasePluginManager: return rep.data def _request_with_plugin_daemon_response_stream( - self, method: str, path: str, type: type[T], headers: dict | None = None, data: bytes | dict | None = None + self, + method: str, + path: str, + type: type[T], + headers: dict | None = None, + data: bytes | dict | None = None, + params: dict | None = None, ) -> Generator[T, None, None]: """ Make a stream request to the plugin daemon inner API and yield the response as a model. """ - for line in self._stream_request(method, path, headers, data): + for line in self._stream_request(method, path, params, headers, data): line_data = json.loads(line) rep = PluginDaemonBasicResponse[type](**line_data) if rep.code != 0: diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/manager/model.py index f03dbfd1e3..4411d76fe1 100644 --- a/api/core/plugin/manager/model.py +++ b/api/core/plugin/manager/model.py @@ -1,5 +1,13 @@ +from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.manager.base import BasePluginManager class PluginModelManager(BasePluginManager): - pass \ No newline at end of file + def fetch_model_providers(self, tenant_id: str) -> list[ProviderEntity]: + """ + Fetch model providers for the given tenant. + """ + response = self._request_with_plugin_daemon_response( + "GET", f"plugin/{tenant_id}/models", list[ProviderEntity], params={"page": 1, "page_size": 256} + ) + return response diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/manager/plugin.py index 101827246a..cabe028ff3 100644 --- a/api/core/plugin/manager/plugin.py +++ b/api/core/plugin/manager/plugin.py @@ -1,5 +1,4 @@ from collections.abc import Generator -from urllib.parse import quote from core.plugin.entities.plugin_daemon import InstallPluginMessage from core.plugin.manager.base import BasePluginManager @@ -9,9 +8,8 @@ class PluginInstallationManager(BasePluginManager): def fetch_plugin_by_identifier(self, tenant_id: str, identifier: str) -> bool: # urlencode the identifier - identifier = quote(identifier) return self._request_with_plugin_daemon_response( - "GET", f"/plugin/{tenant_id}/fetch/identifier?plugin_unique_identifier={identifier}", bool + "GET", f"plugin/{tenant_id}/fetch/identifier", bool, params={"plugin_unique_identifier": identifier} ) def install_from_pkg(self, tenant_id: str, pkg: bytes) -> Generator[InstallPluginMessage, None, None]: @@ -22,21 +20,20 @@ class PluginInstallationManager(BasePluginManager): body = {"dify_pkg": ("dify_pkg", pkg, "application/octet-stream")} return self._request_with_plugin_daemon_response_stream( - "POST", f"/plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body + "POST", f"plugin/{tenant_id}/install/pkg", InstallPluginMessage, data=body ) def install_from_identifier(self, tenant_id: str, identifier: str) -> bool: """ Install a plugin from an identifier. """ - identifier = quote(identifier) # exception will be raised if the request failed return self._request_with_plugin_daemon_response( "POST", - f"/plugin/{tenant_id}/install/identifier", + f"plugin/{tenant_id}/install/identifier", bool, - headers={ - "Content-Type": "application/json", + params={ + "plugin_unique_identifier": identifier, }, data={ "plugin_unique_identifier": identifier, @@ -48,5 +45,5 @@ class PluginInstallationManager(BasePluginManager): Uninstall a plugin. """ return self._request_with_plugin_daemon_response( - "DELETE", f"/plugin/{tenant_id}/uninstall?plugin_unique_identifier={identifier}", bool + "DELETE", f"plugin/{tenant_id}/uninstall", bool, params={"plugin_unique_identifier": identifier} ) diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/manager/tool.py index 10ce33d5e7..f617400355 100644 --- a/api/core/plugin/manager/tool.py +++ b/api/core/plugin/manager/tool.py @@ -1,9 +1,13 @@ from core.plugin.manager.base import BasePluginManager +from core.tools.entities.tool_entities import ToolProviderEntity class PluginToolManager(BasePluginManager): - def fetch_tool_providers(self, asset_id: str) -> list[str]: + def fetch_tool_providers(self, tenant_id: str) -> list[ToolProviderEntity]: """ Fetch tool providers for the given asset. """ - response = self._request('GET', f'/plugin/asset/{asset_id}') \ No newline at end of file + response = self._request_with_plugin_daemon_response( + "GET", f"plugin/{tenant_id}/tools", list[ToolProviderEntity], params={"page": 1, "page_size": 256} + ) + return response diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 80334a274e..2a85c0f882 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -274,9 +274,12 @@ class ToolProviderIdentity(BaseModel): ) -class ToolProviderEntity(BaseModel): - identity: ToolProviderIdentity - credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) +class ToolIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + label: I18nObject = Field(..., description="The label of the tool") + provider: str = Field(..., description="The provider of the tool") + icon: Optional[str] = None class ToolDescription(BaseModel): @@ -284,12 +287,24 @@ class ToolDescription(BaseModel): llm: str = Field(..., description="The description presented to the LLM") -class ToolIdentity(BaseModel): - author: str = Field(..., description="The author of the tool") - name: str = Field(..., description="The name of the tool") - label: I18nObject = Field(..., description="The label of the tool") - provider: str = Field(..., description="The provider of the tool") - icon: Optional[str] = None +class ToolEntity(BaseModel): + identity: ToolIdentity + parameters: list[ToolParameter] = Field(default_factory=list) + description: Optional[ToolDescription] = None + + # pydantic configs + model_config = ConfigDict(protected_namespaces=()) + + @field_validator("parameters", mode="before") + @classmethod + def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: + return v or [] + + +class ToolProviderEntity(BaseModel): + identity: ToolProviderIdentity + credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) + tools: list[ToolEntity] = Field(default_factory=list) class WorkflowToolParameterConfiguration(BaseModel): @@ -352,15 +367,4 @@ class ToolInvokeFrom(Enum): AGENT = "agent" -class ToolEntity(BaseModel): - identity: ToolIdentity - parameters: list[ToolParameter] = Field(default_factory=list) - description: Optional[ToolDescription] = None - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - @field_validator("parameters", mode="before") - @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: - return v or [] diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 2d52399d29..8bb1ab96fc 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -83,3 +83,8 @@ VOLC_EMBEDDING_ENDPOINT_ID= # 360 AI Credentials ZHINAO_API_KEY= + +# Plugin configuration +PLUGIN_API_KEY= +PLUGIN_API_URL= +INNER_API_KEY= \ No newline at end of file diff --git a/api/tests/integration_tests/plugin/__mock/http.py b/api/tests/integration_tests/plugin/__mock/http.py new file mode 100644 index 0000000000..25177274c6 --- /dev/null +++ b/api/tests/integration_tests/plugin/__mock/http.py @@ -0,0 +1,66 @@ +import os +from typing import Literal + +import pytest +import requests +from _pytest.monkeypatch import MonkeyPatch + +from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolProviderEntity, ToolProviderIdentity + + +class MockedHttp: + @classmethod + def list_tools(cls) -> list[ToolProviderEntity]: + return [ + ToolProviderEntity( + identity=ToolProviderIdentity( + author="Yeuoly", + name="Yeuoly", + description=I18nObject(en_US="Yeuoly"), + icon="ssss.svg", + label=I18nObject(en_US="Yeuoly"), + ) + ) + ] + + @classmethod + def requests_request( + cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs + ) -> requests.Response: + """ + Mocked requests.request + """ + request = requests.PreparedRequest() + request.method = method + request.url = url + if url.endswith("/tools"): + content = PluginDaemonBasicResponse[list[ToolProviderEntity]]( + code=0, message="success", data=cls.list_tools() + ).model_dump_json() + else: + raise ValueError("") + + response = requests.Response() + response.status_code = 200 + response.request = request + response._content = content.encode("utf-8") + return response + + +MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_http_mock(request, monkeypatch: MonkeyPatch): + if MOCK_SWITCH: + monkeypatch.setattr(requests, "request", MockedHttp.requests_request) + + def unpatch(): + monkeypatch.undo() + + yield + + if MOCK_SWITCH: + unpatch() diff --git a/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py b/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py new file mode 100644 index 0000000000..d50bba4ecc --- /dev/null +++ b/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py @@ -0,0 +1,9 @@ +from core.plugin.manager.tool import PluginToolManager +from tests.integration_tests.plugin.__mock.http import setup_http_mock + + +def test_fetch_all_plugin_tools(setup_http_mock): + manager = PluginToolManager() + tools = manager.fetch_tool_providers(tenant_id="test-tenant") + assert len(tools) >= 1 + diff --git a/api/tests/integration_tests/tools/test_all_provider.py b/api/tests/integration_tests/tools/test_all_provider.py deleted file mode 100644 index 2dfce749b3..0000000000 --- a/api/tests/integration_tests/tools/test_all_provider.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest - -from core.tools.tool_manager import ToolManager - -provider_generator = ToolManager.list_builtin_providers() -provider_names = [provider.identity.name for provider in provider_generator] -ToolManager.clear_builtin_providers_cache() -provider_generator = ToolManager.list_builtin_providers() - - -@pytest.mark.parametrize("name", provider_names) -def test_tool_providers(benchmark, name): - """ - Test that all tool providers can be loaded - """ - - def test(generator): - try: - return next(generator) - except StopIteration: - return None - - benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)