mirror of https://github.com/langgenius/dify.git
feat: backwards invoke llm
parent d52476c1c9
commit 31e8b134d1
@@ -1,10 +1,13 @@
 import time
+from collections.abc import Generator

 from flask_restful import Resource, reqparse

 from controllers.console.setup import setup_required
 from controllers.inner_api import api
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
+from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
 from core.plugin.entities.request import (
+    RequestInvokeLLM,
     RequestInvokeModeration,
@@ -17,7 +20,6 @@ from core.plugin.entities.request import (
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from libs.helper import compact_generate_response
 from models.account import Tenant
-from services.plugin.plugin_invoke_service import PluginInvokeService


 class PluginInvokeLLMApi(Resource):
@@ -26,7 +28,15 @@ class PluginInvokeLLMApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeLLM)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
-        pass
+        def generator():
+            response = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
+            if isinstance(response, Generator):
+                for chunk in response:
+                    yield chunk.model_dump_json().encode() + b'\n\n'
+            else:
+                yield response.model_dump_json().encode() + b'\n\n'
+
+        return compact_generate_response(generator())


 class PluginInvokeTextEmbeddingApi(Resource):
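Each LLMResultChunk (or the single LLMResult in the non-streaming case) is written out as one JSON document followed by a blank line. A minimal client sketch for consuming that stream is shown below; the route, port, auth header, and body field names are assumptions for illustration, not values taken from this diff.

import json

import requests

# All endpoint details below are assumptions: the actual route, port, and
# authentication scheme of the inner plugin API are not part of this diff.
URL = "http://localhost:5001/inner/api/invoke/llm"
HEADERS = {"Content-Type": "application/json", "X-Inner-Api-Key": "<key>"}

body = {
    "tenant_id": "<tenant-id>",   # hypothetical fields for tenant/user resolution
    "user_id": "<user-id>",
    "provider": "openai",         # placeholder provider and model
    "model": "gpt-4o-mini",
    "mode": "chat",
    "prompt_messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post(URL, headers=HEADERS, json=body, stream=True) as resp:
    for line in resp.iter_lines():
        if not line:
            continue  # blank separator line between chunks
        chunk = json.loads(line)  # one serialized LLMResultChunk per non-empty line
        print(chunk.get("delta", {}).get("message", {}).get("content", ""), end="", flush=True)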
@@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult

@@ -103,7 +103,7 @@ class ModelInstance:
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+            -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
         """
         Invoke large language model
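The narrowed annotation makes the two return shapes explicit: a generator of LLMResultChunk when stream=True, otherwise a single LLMResult. A sketch of how a caller can branch on that union, assuming plain-text message content and placeholder parameters:

from collections.abc import Generator

from core.model_runtime.entities.message_entities import UserPromptMessage


def complete(model_instance, stream: bool = True) -> str:
    # invoke_llm yields LLMResultChunk objects when streaming, or returns one LLMResult.
    response = model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.1},
        stream=stream,
    )
    if isinstance(response, Generator):
        # Each chunk carries an incremental assistant message in chunk.delta.message.
        return "".join(chunk.delta.message.content or "" for chunk in response)
    return response.message.content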
@@ -0,0 +1,49 @@
+from collections.abc import Generator
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.plugin.entities.request import RequestInvokeLLM
+from core.workflow.nodes.llm.llm_node import LLMNode
+from models.account import Tenant
+
+
+class PluginBackwardsInvocation:
+    @classmethod
+    def invoke_llm(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
+    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
+        """
+        invoke llm
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_llm(
+            prompt_messages=payload.prompt_messages,
+            model_parameters=payload.model_parameters,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=payload.stream or True,
+            user=user_id,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunk, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        LLMNode.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+            return response
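The streaming branch wraps the model output in handle(), deducting quota for every chunk that reports usage before yielding it; the non-streaming branch deducts once from the final result. A usage sketch follows, with placeholder provider and model values that are not taken from this diff:

from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
from core.plugin.entities.request import RequestInvokeLLM
from models.account import Tenant


def stream_completion(tenant: Tenant, user_id: str) -> None:
    # Placeholder provider/model; the request model's validator turns the raw
    # role/content dict into a typed UserPromptMessage.
    payload = RequestInvokeLLM(
        provider="openai",
        model="gpt-4o-mini",
        mode="chat",
        prompt_messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )

    # With stream=True this returns a generator of LLMResultChunk objects.
    for chunk in PluginBackwardsInvocation.invoke_llm(user_id, tenant, payload):
        print(chunk.delta.message.content, end="", flush=True)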
@@ -1,4 +1,17 @@
-from pydantic import BaseModel
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageRole,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelType


 class RequestInvokeTool(BaseModel):
@@ -6,37 +19,77 @@ class RequestInvokeTool(BaseModel):
     Request to invoke a tool
     """

-class RequestInvokeLLM(BaseModel):
+
+class BaseRequestInvokeModel(BaseModel):
+    provider: str
+    model: str
+    model_type: ModelType
+
+
+class RequestInvokeLLM(BaseRequestInvokeModel):
     """
     Request to invoke LLM
     """
+
+    model_type: ModelType = ModelType.LLM
+    mode: str
+    model_parameters: dict[str, Any] = Field(default_factory=dict)
+    prompt_messages: list[PromptMessage]
+    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
+    stop: Optional[list[str]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+
+    @field_validator('prompt_messages', mode='before')
+    def convert_prompt_messages(cls, v):
+        if not isinstance(v, list):
+            raise ValueError('prompt_messages must be a list')
+
+        for i in range(len(v)):
+            if v[i]['role'] == PromptMessageRole.USER.value:
+                v[i] = UserPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.ASSISTANT.value:
+                v[i] = AssistantPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.SYSTEM.value:
+                v[i] = SystemPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.TOOL.value:
+                v[i] = ToolPromptMessage(**v[i])
+            else:
+                v[i] = PromptMessage(**v[i])
+
+        return v


 class RequestInvokeTextEmbedding(BaseModel):
     """
     Request to invoke text embedding
     """


 class RequestInvokeRerank(BaseModel):
     """
     Request to invoke rerank
     """


 class RequestInvokeTTS(BaseModel):
     """
     Request to invoke TTS
     """


 class RequestInvokeSpeech2Text(BaseModel):
     """
     Request to invoke speech2text
     """


 class RequestInvokeModeration(BaseModel):
     """
     Request to invoke moderation
     """


 class RequestInvokeNode(BaseModel):
     """
     Request to invoke node
     """
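The mode='before' validator lets callers send plain role/content dicts while the rest of the code keeps working with typed prompt messages. A sketch of that round trip; the provider and model values are placeholders, and the asserts assume Pydantic's default behaviour of keeping already-constructed subclass instances as-is:

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.entities.request import RequestInvokeLLM

# Raw JSON-style messages, as a plugin would send them over the wire.
request = RequestInvokeLLM(
    provider="openai",          # placeholder values, not taken from the diff
    model="gpt-4o-mini",
    mode="chat",
    prompt_messages=[
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name one prime number."},
        {"role": "assistant", "content": "7"},
    ],
)

# The 'before' validator has already replaced each dict with its typed counterpart.
assert isinstance(request.prompt_messages[0], SystemPromptMessage)
assert isinstance(request.prompt_messages[1], UserPromptMessage)
assert isinstance(request.prompt_messages[2], AssistantPromptMessage)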