mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 10:16:40 +08:00
refactor: rename agent to agent strategy
This commit is contained in:
parent
c2983ecbb7
commit
3c628d0c26
@ -6,26 +6,26 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
|
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
|
||||||
|
|
||||||
|
|
||||||
class AgentProviderIdentity(ToolProviderIdentity):
|
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentParameter(ToolParameter):
|
class AgentStrategyParameter(ToolParameter):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentProviderEntity(BaseModel):
|
class AgentStrategyProviderEntity(BaseModel):
|
||||||
identity: AgentProviderIdentity
|
identity: AgentStrategyProviderIdentity
|
||||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||||
|
|
||||||
|
|
||||||
class AgentIdentity(ToolIdentity):
|
class AgentStrategyIdentity(ToolIdentity):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentStrategyEntity(BaseModel):
|
class AgentStrategyEntity(BaseModel):
|
||||||
identity: AgentIdentity
|
identity: AgentStrategyIdentity
|
||||||
parameters: list[AgentParameter] = Field(default_factory=list)
|
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||||
output_schema: Optional[dict] = None
|
output_schema: Optional[dict] = None
|
||||||
|
|
||||||
@ -34,9 +34,9 @@ class AgentStrategyEntity(BaseModel):
|
|||||||
|
|
||||||
@field_validator("parameters", mode="before")
|
@field_validator("parameters", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]:
|
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||||
return v or []
|
return v or []
|
||||||
|
|
||||||
|
|
||||||
class AgentProviderEntityWithPlugin(AgentProviderEntity):
|
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||||
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Generator, Optional, Sequence
|
from typing import Any, Generator, Optional, Sequence
|
||||||
|
|
||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
|
||||||
|
|
||||||
class BaseAgentStrategy(ABC):
|
class BaseAgentStrategy(ABC):
|
||||||
@ -23,7 +23,7 @@ class BaseAgentStrategy(ABC):
|
|||||||
"""
|
"""
|
||||||
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||||
|
|
||||||
def get_parameters(self) -> Sequence[AgentParameter]:
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
"""
|
"""
|
||||||
Get the parameters for the agent strategy.
|
Get the parameters for the agent strategy.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Generator, Optional, Sequence
|
from typing import Any, Generator, Optional, Sequence
|
||||||
|
|
||||||
from core.agent.entities import AgentInvokeMessage
|
from core.agent.entities import AgentInvokeMessage
|
||||||
from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity
|
from core.agent.plugin_entities import AgentStrategyParameter, AgentStrategyEntity
|
||||||
from core.agent.strategy.base import BaseAgentStrategy
|
from core.agent.strategy.base import BaseAgentStrategy
|
||||||
from core.plugin.manager.agent import PluginAgentManager
|
from core.plugin.manager.agent import PluginAgentManager
|
||||||
from core.tools.plugin_tool.tool import PluginTool
|
from core.tools.plugin_tool.tool import PluginTool
|
||||||
@ -21,7 +21,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
|||||||
self.plugin_unique_identifier = plugin_unique_identifier
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
self.declaration = declaration
|
self.declaration = declaration
|
||||||
|
|
||||||
def get_parameters(self) -> Sequence[AgentParameter]:
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
return self.declaration.parameters
|
return self.declaration.parameters
|
||||||
|
|
||||||
def _invoke(
|
def _invoke(
|
||||||
|
|||||||
@ -43,7 +43,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
# invoke model
|
# invoke model
|
||||||
response = model_instance.invoke_llm(
|
response = model_instance.invoke_llm(
|
||||||
prompt_messages=payload.prompt_messages,
|
prompt_messages=payload.prompt_messages,
|
||||||
model_parameters=payload.model_parameters,
|
model_parameters=payload.completion_params,
|
||||||
tools=payload.tools,
|
tools=payload.tools,
|
||||||
stop=payload.stop,
|
stop=payload.stop,
|
||||||
stream=payload.stream or True,
|
stream=payload.stream or True,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||||
from core.plugin.entities.base import BasePluginEntity
|
from core.plugin.entities.base import BasePluginEntity
|
||||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||||
@ -59,6 +60,7 @@ class PluginCategory(enum.StrEnum):
|
|||||||
Tool = "tool"
|
Tool = "tool"
|
||||||
Model = "model"
|
Model = "model"
|
||||||
Extension = "extension"
|
Extension = "extension"
|
||||||
|
AgentStrategy = "agent_strategy"
|
||||||
|
|
||||||
|
|
||||||
class PluginDeclaration(BaseModel):
|
class PluginDeclaration(BaseModel):
|
||||||
@ -82,6 +84,7 @@ class PluginDeclaration(BaseModel):
|
|||||||
tool: Optional[ToolProviderEntity] = None
|
tool: Optional[ToolProviderEntity] = None
|
||||||
model: Optional[ProviderEntity] = None
|
model: Optional[ProviderEntity] = None
|
||||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||||
|
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -91,6 +94,8 @@ class PluginDeclaration(BaseModel):
|
|||||||
values["category"] = PluginCategory.Tool
|
values["category"] = PluginCategory.Tool
|
||||||
elif values.get("model"):
|
elif values.get("model"):
|
||||||
values["category"] = PluginCategory.Model
|
values["category"] = PluginCategory.Model
|
||||||
|
elif values.get("agent_strategy"):
|
||||||
|
values["category"] = PluginCategory.AgentStrategy
|
||||||
else:
|
else:
|
||||||
values["category"] = PluginCategory.Extension
|
values["category"] = PluginCategory.Extension
|
||||||
return values
|
return values
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||||||
|
|
||||||
model_type: ModelType = ModelType.LLM
|
model_type: ModelType = ModelType.LLM
|
||||||
mode: str
|
mode: str
|
||||||
model_parameters: dict[str, Any] = Field(default_factory=dict)
|
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
||||||
stop: Optional[list[str]] = Field(default_factory=list)
|
stop: Optional[list[str]] = Field(default_factory=list)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from core.plugin.manager.base import BasePluginManager
|
|||||||
|
|
||||||
|
|
||||||
class PluginAgentManager(BasePluginManager):
|
class PluginAgentManager(BasePluginManager):
|
||||||
def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Fetch agent providers for the given tenant.
|
Fetch agent providers for the given tenant.
|
||||||
"""
|
"""
|
||||||
@ -26,7 +26,7 @@ class PluginAgentManager(BasePluginManager):
|
|||||||
|
|
||||||
response = self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response(
|
||||||
"GET",
|
"GET",
|
||||||
f"plugin/{tenant_id}/management/agents",
|
f"plugin/{tenant_id}/management/agent_strategies",
|
||||||
list[PluginAgentProviderEntity],
|
list[PluginAgentProviderEntity],
|
||||||
params={"page": 1, "page_size": 256},
|
params={"page": 1, "page_size": 256},
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
@ -41,7 +41,7 @@ class PluginAgentManager(BasePluginManager):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
||||||
"""
|
"""
|
||||||
Fetch tool provider for the given tenant and plugin.
|
Fetch tool provider for the given tenant and plugin.
|
||||||
"""
|
"""
|
||||||
@ -55,7 +55,7 @@ class PluginAgentManager(BasePluginManager):
|
|||||||
|
|
||||||
response = self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response(
|
||||||
"GET",
|
"GET",
|
||||||
f"plugin/{tenant_id}/management/agent",
|
f"plugin/{tenant_id}/management/agent_strategy",
|
||||||
PluginAgentProviderEntity,
|
PluginAgentProviderEntity,
|
||||||
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
|
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
|
||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
@ -96,9 +96,9 @@ class PluginAgentManager(BasePluginManager):
|
|||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"data": {
|
"data": {
|
||||||
"provider": agent_provider_id.provider_name,
|
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||||
"strategy": agent_strategy,
|
"agent_strategy": agent_strategy,
|
||||||
"agent_params": agent_params,
|
"agent_strategy_params": agent_params,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
headers={
|
headers={
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any, Sequence, cast
|
from typing import Any, Sequence, cast
|
||||||
|
|
||||||
from core.agent.plugin_entities import AgentParameter
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
@ -90,7 +90,7 @@ class AgentNode(ToolNode):
|
|||||||
def _generate_parameters(
|
def _generate_parameters(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
agent_parameters: Sequence[AgentParameter],
|
agent_parameters: Sequence[AgentStrategyParameter],
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
node_data: AgentNodeData,
|
node_data: AgentNodeData,
|
||||||
for_log: bool = False,
|
for_log: bool = False,
|
||||||
|
|||||||
@ -246,7 +246,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||||
text += message.message.text + "\n"
|
text += message.message.text
|
||||||
yield RunStreamChunkEvent(
|
yield RunStreamChunkEvent(
|
||||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -7,7 +7,7 @@ def get_plugin_agent_strategy(
|
|||||||
) -> PluginAgentStrategy:
|
) -> PluginAgentStrategy:
|
||||||
# TODO: use contexts to cache the agent provider
|
# TODO: use contexts to cache the agent provider
|
||||||
manager = PluginAgentManager()
|
manager = PluginAgentManager()
|
||||||
agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name)
|
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
|
||||||
for agent_strategy in agent_provider.declaration.strategies:
|
for agent_strategy in agent_provider.declaration.strategies:
|
||||||
if agent_strategy.identity.name == agent_strategy_name:
|
if agent_strategy.identity.name == agent_strategy_name:
|
||||||
return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)
|
return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)
|
||||||
|
|||||||
@ -160,7 +160,7 @@ class AgentService:
|
|||||||
List agent providers
|
List agent providers
|
||||||
"""
|
"""
|
||||||
manager = PluginAgentManager()
|
manager = PluginAgentManager()
|
||||||
return manager.fetch_agent_providers(tenant_id)
|
return manager.fetch_agent_strategy_providers(tenant_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
|
def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
|
||||||
@ -168,4 +168,4 @@ class AgentService:
|
|||||||
Get agent provider
|
Get agent provider
|
||||||
"""
|
"""
|
||||||
manager = PluginAgentManager()
|
manager = PluginAgentManager()
|
||||||
return manager.fetch_agent_provider(tenant_id, provider_name)
|
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user