feat: backwards invoke llm

This commit is contained in:
Yeuoly 2024-07-29 22:08:14 +08:00
parent d52476c1c9
commit 31e8b134d1
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
4 changed files with 119 additions and 7 deletions

View File

@ -1,10 +1,13 @@
import time
from collections.abc import Generator
from flask_restful import Resource, reqparse
from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
@ -17,7 +20,6 @@ from core.plugin.entities.request import (
from core.tools.entities.tool_entities import ToolInvokeMessage
from libs.helper import compact_generate_response
from models.account import Tenant
from services.plugin.plugin_invoke_service import PluginInvokeService
class PluginInvokeLLMApi(Resource):
@ -26,7 +28,15 @@ class PluginInvokeLLMApi(Resource):
@get_tenant
@plugin_data(payload_type=RequestInvokeLLM)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
pass
def generator():
response = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
if isinstance(response, Generator):
for chunk in response:
yield chunk.model_dump_json().encode() + b'\n\n'
else:
yield response.model_dump_json().encode() + b'\n\n'
return compact_generate_response(generator())
class PluginInvokeTextEmbeddingApi(Resource):

View File

@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
@ -103,7 +103,7 @@ class ModelInstance:
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
-> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
"""
Invoke large language model

View File

@ -0,0 +1,49 @@
from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.entities.request import RequestInvokeLLM
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant
class PluginBackwardsInvocation:
@classmethod
def invoke_llm(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
) -> Generator[LLMResultChunk, None, None] | LLMResult:
"""
invoke llm
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
# invoke model
response = model_instance.invoke_llm(
prompt_messages=payload.prompt_messages,
model_parameters=payload.model_parameters,
tools=payload.tools,
stop=payload.stop,
stream=payload.stream or True,
user=user_id,
)
if isinstance(response, Generator):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
LLMNode.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
yield chunk
return handle()
else:
if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
return response

View File

@ -1,4 +1,17 @@
from pydantic import BaseModel
from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
class RequestInvokeTool(BaseModel):
@ -6,37 +19,77 @@ class RequestInvokeTool(BaseModel):
Request to invoke a tool
"""
class RequestInvokeLLM(BaseModel):
class BaseRequestInvokeModel(BaseModel):
provider: str
model: str
model_type: ModelType
class RequestInvokeLLM(BaseRequestInvokeModel):
"""
Request to invoke LLM
"""
model_type: ModelType = ModelType.LLM
mode: str
model_parameters: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage]
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
stop: Optional[list[str]] = Field(default_factory=list)
stream: Optional[bool] = False
@field_validator('prompt_messages', mode='before')
def convert_prompt_messages(cls, v):
if not isinstance(v, list):
raise ValueError('prompt_messages must be a list')
for i in range(len(v)):
if v[i]['role'] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i])
elif v[i]['role'] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i])
elif v[i]['role'] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i])
elif v[i]['role'] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i])
else:
v[i] = PromptMessage(**v[i])
return v
class RequestInvokeTextEmbedding(BaseModel):
"""
Request to invoke text embedding
"""
class RequestInvokeRerank(BaseModel):
"""
Request to invoke rerank
"""
class RequestInvokeTTS(BaseModel):
"""
Request to invoke TTS
"""
class RequestInvokeSpeech2Text(BaseModel):
"""
Request to invoke speech2text
"""
class RequestInvokeModeration(BaseModel):
"""
Request to invoke moderation
"""
class RequestInvokeNode(BaseModel):
"""
Request to invoke node
"""
"""