feat: implement invoke app args

This commit is contained in:
Yeuoly 2024-08-29 20:50:36 +08:00
parent 41ed2e0cc2
commit 12ea085e22
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
5 changed files with 99 additions and 82 deletions

View File

@ -6,10 +6,13 @@ from controllers.console.setup import setup_required
from controllers.inner_api import api from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeLLM, RequestInvokeLLM,
RequestInvokeModeration, RequestInvokeModeration,
RequestInvokeNode,
RequestInvokeRerank, RequestInvokeRerank,
RequestInvokeSpeech2Text, RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding, RequestInvokeTextEmbedding,
@ -104,21 +107,33 @@ class PluginInvokeNodeApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_tenant @get_tenant
def post(self, user_id: str, tenant_model: Tenant): @plugin_data(payload_type=RequestInvokeNode)
parser = reqparse.RequestParser() def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode):
args = parser.parse_args() pass
return {'message': 'success'}
class PluginInvokeAppApi(Resource): class PluginInvokeAppApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_tenant @get_tenant
def post(self, user_id: str, tenant_model: Tenant): @plugin_data(payload_type=RequestInvokeApp)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeApp):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
args = parser.parse_args() args = parser.parse_args()
return {'message': 'success'} response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.stream,
inputs=payload.inputs,
files=payload.files
)
return compact_generate_response(
PluginAppBackwardsInvocation.convert_to_event_stream(response)
)
api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')

View File

@ -1,7 +1,10 @@
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Literal, Union from typing import Optional, Union
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@ -16,20 +19,29 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
cls, app_id: str, cls, app_id: str,
user_id: str, user_id: str,
tenant_id: str, tenant_id: str,
query: str, conversation_id: Optional[str],
query: Optional[str],
stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
) -> Generator[dict, None, None] | dict: ) -> Generator[dict | str, None, None] | dict:
""" """
invoke app invoke app
""" """
app = cls._get_app(app_id, tenant_id) app = cls._get_app(app_id, tenant_id)
user = cls._get_user(user_id)
conversation_id = conversation_id or ""
if app.mode in [AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value]: if app.mode in [AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value]:
return cls.invoke_chat_app(app, user_id, tenant_id, query, inputs, files) if not query:
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
elif app.mode in [AppMode.WORKFLOW.value]: elif app.mode in [AppMode.WORKFLOW.value]:
return cls.invoke_workflow_app(app, user_id, tenant_id, inputs, files) return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode in [AppMode.COMPLETION]: elif app.mode in [AppMode.COMPLETION]:
return cls.invoke_completion_app(app, user_id, tenant_id, inputs, files) return cls.invoke_completion_app(app, user, stream, inputs, files)
raise ValueError("unexpected app type") raise ValueError("unexpected app type")
@ -38,13 +50,12 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
cls, cls,
app: App, app: App,
user: Account | EndUser, user: Account | EndUser,
tenant_id: str,
conversation_id: str, conversation_id: str,
query: str, query: str,
stream: bool, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
) -> Generator[dict, None, None] | dict: ) -> Generator[dict | str, None, None] | dict:
""" """
invoke chat app invoke chat app
""" """
@ -53,25 +64,54 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow: if not workflow:
raise ValueError("unexpected app type") raise ValueError("unexpected app type")
generator = AdvancedChatAppGenerator() return AdvancedChatAppGenerator().generate(
response = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
user=user, user=user,
args={ args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
}, },
invoke_from=InvokeFrom.SERVICE_API, invoke_from=InvokeFrom.SERVICE_API,
stream=stream stream=stream
) )
elif app.mode == AppMode.AGENT_CHAT.value:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
stream=stream
)
elif app.mode == AppMode.CHAT.value:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
stream=stream
)
else:
raise ValueError("unexpected app type")
@classmethod @classmethod
def invoke_workflow_app( def invoke_workflow_app(
cls, cls,
app: App, app: App,
user_id: str, user: EndUser | Account,
tenant_id: str, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ):
@ -82,33 +122,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow: if not workflow:
raise ValueError("") raise ValueError("")
generator = WorkflowAppGenerator() return WorkflowAppGenerator().generate(
result = generator.generate(
app_model=app, app_model=app,
workflow=workflow, workflow=workflow,
user=cls._get_user(user_id), user=user,
args={ args={
'inputs': tool_parameters, 'inputs': inputs,
'files': files 'files': files
}, },
invoke_from=self.runtime.invoke_from, invoke_from=InvokeFrom.SERVICE_API,
stream=False, stream=stream,
call_depth=self.workflow_call_depth + 1, call_depth=1,
) )
@classmethod @classmethod
def invoke_completion_app( def invoke_completion_app(
cls, cls,
app: App, app: App,
user_id: str, user: EndUser | Account,
tenant_id: str, stream: bool,
inputs: Mapping, inputs: Mapping,
files: list[dict], files: list[dict],
): ):
""" """
invoke completion app invoke completion app
""" """
return CompletionAppGenerator().generate(
app_model=app,
user=user,
args={
'inputs': inputs,
'files': files
},
invoke_from=InvokeFrom.SERVICE_API,
stream=stream,
)
@classmethod @classmethod
def _get_user(cls, user_id: str) -> Union[EndUser, Account]: def _get_user(cls, user_id: str) -> Union[EndUser, Account]:

View File

@ -6,15 +6,17 @@ from pydantic import BaseModel
class BaseBackwardsInvocation: class BaseBackwardsInvocation:
@classmethod @classmethod
def convert_to_event_stream(cls, response: Generator[BaseModel | dict, None, None] | BaseModel | dict): def convert_to_event_stream(cls, response: Generator[BaseModel | dict | str, None, None] | BaseModel | dict):
if isinstance(response, Generator): if isinstance(response, Generator):
for chunk in response: for chunk in response:
if isinstance(chunk, BaseModel): if isinstance(chunk, BaseModel):
yield chunk.model_dump_json().encode() + b'\n\n' yield chunk.model_dump_json().encode() + b'\n\n'
if isinstance(chunk, str):
yield f"event: {chunk}\n\n".encode()
else: else:
yield json.dumps(chunk).encode() + b'\n\n' yield json.dumps(chunk).encode() + b'\n\n'
else: else:
if isinstance(response, BaseModel): if isinstance(response, BaseModel):
yield response.model_dump_json().encode() + b'\n\n' yield response.model_dump_json().encode() + b'\n\n'
else: else:
yield json.dumps(response).encode() + b'\n\n' yield json.dumps(response).encode() + b'\n\n'

View File

@ -105,3 +105,4 @@ class RequestInvokeApp(BaseModel):
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
user: Optional[str] = None user: Optional[str] = None
files: list[dict] = Field(default_factory=list) files: list[dict] = Field(default_factory=list)
stream: bool = Field(default=False)

View File

@ -1,49 +0,0 @@
from collections.abc import Generator
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeType
from models.account import Tenant
from services.tools.tools_transform_service import ToolTransformService
class PluginInvokeService:
@classmethod
def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant,
tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str,
tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
"""
Invokes a tool with the given user ID and tool parameters.
"""
tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider,
tool_name=tool_name, tenant_id=tenant.id,
invoke_from=invoke_from)
response = ToolEngine.plugin_invoke(tool_runtime,
tool_parameters,
user_id,
callback=DifyPluginCallbackHandler())
response = ToolFileMessageTransformer.transform_tool_invoke_messages(response)
return ToolTransformService.transform_messages_to_dict(response)
@classmethod
def invoke_model(cls, user_id: str, tenant: Tenant,
model_provider: str, model_name: str, model_type: ModelType,
model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]:
"""
Invokes a model with the given user ID and model parameters.
"""
@classmethod
def invoke_workflow_node(cls, user_id: str, tenant: Tenant,
node_type: NodeType, node_data: dict[str, Any],
inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]:
"""
Invokes a workflow node with the given user ID and node parameters.
"""