From 41ed2e0cc291543ee7338b14bf416c5bcfab7255 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 29 Aug 2024 20:17:17 +0800 Subject: [PATCH] feat: backwards invoke app --- api/controllers/inner_api/plugin/plugin.py | 8 +- api/core/plugin/backwards_invocation/app.py | 141 ++++++++++++++++++ api/core/plugin/backwards_invocation/base.py | 20 +++ api/core/plugin/backwards_invocation/model.py | 5 +- api/core/plugin/entities/request.py | 14 +- api/libs/login.py | 3 +- api/models/tools.py | 9 +- 7 files changed, 187 insertions(+), 13 deletions(-) create mode 100644 api/core/plugin/backwards_invocation/app.py create mode 100644 api/core/plugin/backwards_invocation/base.py diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index d9b58e1e93..dfe02b7635 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,5 +1,4 @@ import time -from collections.abc import Generator from flask_restful import Resource, reqparse @@ -30,15 +29,10 @@ class PluginInvokeLLMApi(Resource): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM): def generator(): response = PluginModelBackwardsInvocation.invoke_llm(user_id, tenant_model, payload) - if isinstance(response, Generator): - for chunk in response: - yield chunk.model_dump_json().encode() + b'\n\n' - else: - yield response.model_dump_json().encode() + b'\n\n' + return PluginModelBackwardsInvocation.convert_to_event_stream(response) return compact_generate_response(generator()) - class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py new file mode 100644 index 0000000000..7304c637fa --- /dev/null +++ b/api/core/plugin/backwards_invocation/app.py @@ -0,0 +1,141 @@ +from collections.abc import Generator, Mapping +from typing import Literal, Union + +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation +from extensions.ext_database import db +from models.account import Account +from models.model import App, AppMode, EndUser + + +class PluginAppBackwardsInvocation(BaseBackwardsInvocation): + @classmethod + def invoke_app( + cls, app_id: str, + user_id: str, + tenant_id: str, + query: str, + inputs: Mapping, + files: list[dict], + ) -> Generator[dict, None, None] | dict: + """ + invoke app + """ + app = cls._get_app(app_id, tenant_id) + if app.mode in [AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value]: + return cls.invoke_chat_app(app, user_id, tenant_id, query, inputs, files) + elif app.mode in [AppMode.WORKFLOW.value]: + return cls.invoke_workflow_app(app, user_id, tenant_id, inputs, files) + elif app.mode in [AppMode.COMPLETION]: + return cls.invoke_completion_app(app, user_id, tenant_id, inputs, files) + + raise ValueError("unexpected app type") + + @classmethod + def invoke_chat_app( + cls, + app: App, + user: Account | EndUser, + tenant_id: str, + conversation_id: str, + query: str, + stream: bool, + inputs: Mapping, + files: list[dict], + ) -> Generator[dict, None, None] | dict: + """ + invoke chat app + """ + if app.mode == AppMode.ADVANCED_CHAT.value: + workflow = app.workflow + if not workflow: + raise ValueError("unexpected app type") + + generator = AdvancedChatAppGenerator() + response = generator.generate( + app_model=app, + workflow=workflow, + user=user, + args={ + }, + invoke_from=InvokeFrom.SERVICE_API, + stream=stream + ) + + + + @classmethod + def invoke_workflow_app( + cls, + app: App, + user_id: str, + tenant_id: str, + inputs: Mapping, + files: list[dict], + ): + """ + invoke workflow app + """ + workflow = app.workflow + if not workflow: + raise ValueError("") + + generator = WorkflowAppGenerator() + + result = generator.generate( + app_model=app, + workflow=workflow, + user=cls._get_user(user_id), + args={ + 'inputs': tool_parameters, + 'files': files + }, + invoke_from=self.runtime.invoke_from, + stream=False, + call_depth=self.workflow_call_depth + 1, + ) + + @classmethod + def invoke_completion_app( + cls, + app: App, + user_id: str, + tenant_id: str, + inputs: Mapping, + files: list[dict], + ): + """ + invoke completion app + """ + + @classmethod + def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + """ + get the user by user id + """ + + user = db.session.query(EndUser).filter(EndUser.id == user_id).first() + if not user: + user = db.session.query(Account).filter(Account.id == user_id).first() + + if not user: + raise ValueError('user not found') + + return user + + @classmethod + def _get_app(cls, app_id: str, tenant_id: str) -> App: + """ + get app + """ + app = db.session.query(App). \ + filter(App.id == app_id). \ + filter(App.tenant_id == tenant_id). \ + first() + + if not app: + raise ValueError("app not found") + + return app \ No newline at end of file diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py new file mode 100644 index 0000000000..28b691d990 --- /dev/null +++ b/api/core/plugin/backwards_invocation/base.py @@ -0,0 +1,20 @@ +import json +from collections.abc import Generator + +from pydantic import BaseModel + + +class BaseBackwardsInvocation: + @classmethod + def convert_to_event_stream(cls, response: Generator[BaseModel | dict, None, None] | BaseModel | dict): + if isinstance(response, Generator): + for chunk in response: + if isinstance(chunk, BaseModel): + yield chunk.model_dump_json().encode() + b'\n\n' + else: + yield json.dumps(chunk).encode() + b'\n\n' + else: + if isinstance(response, BaseModel): + yield response.model_dump_json().encode() + b'\n\n' + else: + yield json.dumps(response).encode() + b'\n\n' \ No newline at end of file diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index b6da133119..7904fd6234 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,12 +2,13 @@ from collections.abc import Generator from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.entities.request import RequestInvokeLLM from core.workflow.nodes.llm.llm_node import LLMNode from models.account import Tenant -class PluginBackwardsInvocation: +class PluginModelBackwardsInvocation(BaseBackwardsInvocation): @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -47,3 +48,5 @@ class PluginBackwardsInvocation: if response.usage: LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) return response + + \ No newline at end of file diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index bb08facf75..d7781ba375 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -93,3 +93,15 @@ class RequestInvokeNode(BaseModel): """ Request to invoke node """ + +class RequestInvokeApp(BaseModel): + """ + Request to invoke app + """ + app_id: str + inputs: dict[str, Any] + query: Optional[str] = None + response_mode: Literal["blocking", "streaming"] + conversation_id: Optional[str] = None + user: Optional[str] = None + files: list[dict] = Field(default_factory=list) diff --git a/api/libs/login.py b/api/libs/login.py index 7f05eb8404..8431d967bd 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -9,6 +9,7 @@ from werkzeug.local import LocalProxy from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin +from models.model import EndUser #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user @@ -96,7 +97,7 @@ def login_required(func): return decorated_view -def _get_user(): +def _get_user() -> EndUser | Account | None: if has_request_context(): if "_login_user" not in g: current_app.login_manager._load_user() diff --git a/api/models/tools.py b/api/models/tools.py index 3ee246eeb3..937481583a 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -278,7 +278,10 @@ class ToolConversationVariables(db.Model): def variables(self) -> dict: return json.loads(self.variables_str) -class ToolFile(DeclarativeBase): +class Base(DeclarativeBase): + pass + +class ToolFile(Base): """ store the file created by agent """ @@ -293,9 +296,9 @@ class ToolFile(DeclarativeBase): # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id - tenant_id: Mapped[StringUUID] = mapped_column(StringUUID) + tenant_id: Mapped[str] = mapped_column(StringUUID) # conversation id - conversation_id: Mapped[StringUUID] = mapped_column(nullable=True) + conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) # mime type