From 31e8b134d14f492dbf256067db184f83e4eef24e Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Mon, 29 Jul 2024 22:08:14 +0800
Subject: [PATCH] feat: backwards invoke llm

---
 api/controllers/inner_api/plugin/plugin.py    | 14 ++++-
 api/core/model_manager.py                     |  4 +-
 api/core/plugin/backwards_invocation/model.py | 49 +++++++++++++++
 api/core/plugin/entities/request.py           | 59 ++++++++++++++++++-
 4 files changed, 119 insertions(+), 7 deletions(-)
 create mode 100644 api/core/plugin/backwards_invocation/model.py

diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index b3ebf81bf6..3a76e00767 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -1,10 +1,13 @@
 import time
+from collections.abc import Generator
+
 from flask_restful import Resource, reqparse
 
 from controllers.console.setup import setup_required
 from controllers.inner_api import api
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
+from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
     RequestInvokeModeration,
@@ -17,7 +20,6 @@ from core.plugin.entities.request import (
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from libs.helper import compact_generate_response
 from models.account import Tenant
-from services.plugin.plugin_invoke_service import PluginInvokeService
 
 
 class PluginInvokeLLMApi(Resource):
@@ -26,7 +28,15 @@ class PluginInvokeLLMApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeLLM)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
-        pass
+        def generator():
+            response = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
+            if isinstance(response, Generator):
+                for chunk in response:
+                    yield chunk.model_dump_json().encode() + b'\n\n'
+            else:
+                yield response.model_dump_json().encode() + b'\n\n'
+
+        return compact_generate_response(generator())
 
 
 class PluginInvokeTextEmbeddingApi(Resource):
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 8e99ad3dec..e46d1d35ee 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
@@ -103,7 +103,7 @@ class ModelInstance:
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+            -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
         """
         Invoke large language model
 
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
new file mode 100644
index 0000000000..b6da133119
--- /dev/null
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -0,0 +1,49 @@
+from collections.abc import Generator
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.plugin.entities.request import RequestInvokeLLM
+from core.workflow.nodes.llm.llm_node import LLMNode
+from models.account import Tenant
+
+
+class PluginBackwardsInvocation:
+    @classmethod
+    def invoke_llm(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
+    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
+        """
+        invoke llm
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_llm(
+            prompt_messages=payload.prompt_messages,
+            model_parameters=payload.model_parameters,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=payload.stream or True,
+            user=user_id,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunk, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        LLMNode.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+            return response
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index 62db6396a8..bb08facf75 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -1,4 +1,17 @@
-from pydantic import BaseModel
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageRole,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelType
 
 
 class RequestInvokeTool(BaseModel):
@@ -6,37 +19,77 @@ class RequestInvokeTool(BaseModel):
     Request to invoke a tool
     """
 
-class RequestInvokeLLM(BaseModel):
+
+class BaseRequestInvokeModel(BaseModel):
+    provider: str
+    model: str
+    model_type: ModelType
+
+
+class RequestInvokeLLM(BaseRequestInvokeModel):
     """
     Request to invoke LLM
     """
 
+    model_type: ModelType = ModelType.LLM
+    mode: str
+    model_parameters: dict[str, Any] = Field(default_factory=dict)
+    prompt_messages: list[PromptMessage]
+    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
+    stop: Optional[list[str]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+
+    @field_validator('prompt_messages', mode='before')
+    def convert_prompt_messages(cls, v):
+        if not isinstance(v, list):
+            raise ValueError('prompt_messages must be a list')
+
+        for i in range(len(v)):
+            if v[i]['role'] == PromptMessageRole.USER.value:
+                v[i] = UserPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.ASSISTANT.value:
+                v[i] = AssistantPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.SYSTEM.value:
+                v[i] = SystemPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.TOOL.value:
+                v[i] = ToolPromptMessage(**v[i])
+            else:
+                v[i] = PromptMessage(**v[i])
+
+        return v
+
+
 class RequestInvokeTextEmbedding(BaseModel):
     """
     Request to invoke text embedding
     """
 
+
 class RequestInvokeRerank(BaseModel):
     """
     Request to invoke rerank
     """
 
+
 class RequestInvokeTTS(BaseModel):
     """
     Request to invoke TTS
     """
 
+
 class RequestInvokeSpeech2Text(BaseModel):
     """
     Request to invoke speech2text
     """
 
+
 class RequestInvokeModeration(BaseModel):
     """
     Request to invoke moderation
     """
 
+
 class RequestInvokeNode(BaseModel):
     """
     Request to invoke node
-    """
\ No newline at end of file
+    """