diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index f2d829c168..7a980d6e39 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -19,6 +19,7 @@ from core.plugin.entities.request import (
     RequestInvokeQuestionClassifierNode,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
+    RequestInvokeSummary,
     RequestInvokeTextEmbedding,
     RequestInvokeTool,
     RequestInvokeTTS,
@@ -230,6 +231,24 @@ class PluginInvokeEncryptApi(Resource):
             return BaseBackwardsInvocationResponse(error=str(e)).model_dump()


+class PluginInvokeSummaryApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_tenant
+    @plugin_data(payload_type=RequestInvokeSummary)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSummary):
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginModelBackwardsInvocation.invoke_summary(
+                    user_id=user_id,
+                    tenant=tenant_model,
+                    payload=payload,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
+
+
 api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
 api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
 api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@@ -241,3 +260,4 @@ api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extra
 api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
 api.add_resource(PluginInvokeAppApi, "/invoke/app")
 api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
+api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 405fcb069d..377512886a 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -4,15 +4,23 @@ from collections.abc import Generator

 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.model_runtime.entities.message_entities import (
+    PromptMessage,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
     RequestInvokeModeration,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
+    RequestInvokeSummary,
     RequestInvokeTextEmbedding,
     RequestInvokeTTS,
 )
+from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.utils.model_invocation_utils import ModelInvocationUtils
 from core.workflow.nodes.llm.llm_node import LLMNode
 from models.account import Tenant

@@ -175,3 +183,139 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         return {
             "result": response,
         }
+
+    @classmethod
+    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
+        """
+        get system model max tokens
+        """
+        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
+
+    @classmethod
+    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
+        """
+        get prompt tokens
+        """
+        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
+
+    @classmethod
+    def invoke_system_model(
+        cls,
+        user_id: str,
+        tenant: Tenant,
+        prompt_messages: list[PromptMessage],
+    ) -> LLMResult:
+        """
+        invoke system model
+        """
+        return ModelInvocationUtils.invoke(
+            user_id=user_id,
+            tenant_id=tenant.id,
+            tool_type=ToolProviderType.PLUGIN,
+            tool_name="plugin",
+            prompt_messages=prompt_messages,
+        )
+
+    @classmethod
+    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary) -> str:
+        """
+        summarize the text, chunking and recursing when it exceeds the model's context budget
+        """
+        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
+        content = payload.text
+
+        SUMMARY_PROMPT = """You are a professional language researcher. You can quickly identify
+the main points of a webpage and reproduce them in your own words while retaining the original
+meaning and keeping the key points.
+However, the text you received is long and may be only part of the full text.
+Please summarize the text you received.
+
+Here is the extra instruction you need to follow:
+
+{payload.instruction}
+
+"""
+
+        # short inputs need no summarization; return them unchanged
+        if (
+            cls.get_prompt_tokens(
+                tenant_id=tenant.id,
+                prompt_messages=[UserPromptMessage(content=content)],
+            )
+            < max_tokens * 0.6
+        ):
+            return content
+
+        def get_prompt_tokens(content: str) -> int:
+            return cls.get_prompt_tokens(
+                tenant_id=tenant.id,
+                prompt_messages=[
+                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
+                    UserPromptMessage(content=content),
+                ],
+            )
+
+        def summarize(content: str) -> str:
+            summary = cls.invoke_system_model(
+                user_id=user_id,
+                tenant=tenant,
+                prompt_messages=[
+                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
+                    UserPromptMessage(content=content),
+                ],
+            )
+
+            assert isinstance(summary.message.content, str)
+            return summary.message.content
+
+        lines = content.split("\n")
+        new_lines = []
+        # split overly long lines into chunks; character length is used as a cheap
+        # proxy for token count before the more expensive token check
+        for line in lines:
+            if not line.strip():
+                continue
+            if len(line) < max_tokens * 0.5:
+                new_lines.append(line)
+            elif get_prompt_tokens(line) > max_tokens * 0.7:
+                while get_prompt_tokens(line) > max_tokens * 0.7:
+                    new_lines.append(line[: int(max_tokens * 0.5)])
+                    line = line[int(max_tokens * 0.5) :]
+                new_lines.append(line)
+            else:
+                new_lines.append(line)
+
+        # merge lines into messages, keeping each message within the token budget;
+        # the cheap character check runs before the more expensive token check
+        messages: list[str] = []
+        for line in new_lines:
+            if not messages:
+                messages.append(line)
+            elif len(messages[-1]) + len(line) < max_tokens * 0.5:
+                messages[-1] += line
+            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
+                messages.append(line)
+            else:
+                messages[-1] += line
+
+        # summarize each chunk independently, then join the partial summaries
+        summaries = [summarize(message) for message in messages]
+        result = "\n".join(summaries)
+
+        # if the joined summaries are still too long, summarize them recursively
+        if (
+            cls.get_prompt_tokens(
+                tenant_id=tenant.id,
+                prompt_messages=[UserPromptMessage(content=result)],
+            )
+            > max_tokens * 0.7
+        ):
+            return cls.invoke_summary(
+                user_id=user_id,
+                tenant=tenant,
+                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
+            )
+
+        return result
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index ae94bc95f6..d98b80ee43 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -186,3 +186,12 @@ class RequestInvokeEncrypt(BaseModel):
     identity: str
     data: dict = Field(default_factory=dict)
     config: list[BasicProviderConfig] = Field(default_factory=list)
+
+
+class RequestInvokeSummary(BaseModel):
+    """
+    Request to summarize a piece of text
+    """
+
+    text: str
+    instruction: str
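
For reviewers, a minimal sketch of exercising the new invocation server-side, built only from the signatures in this diff. `tenant` (a models.account.Tenant row), `long_text`, and the user id are placeholders, and it assumes the tenant already has a system model configured, since invoke_system_model delegates to ModelInvocationUtils.invoke:

# minimal sketch; placeholders: `tenant` and `long_text`, assumes a Flask app context
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.entities.request import RequestInvokeSummary

summary = PluginModelBackwardsInvocation.invoke_summary(
    user_id="a-user-id",  # placeholder
    tenant=tenant,
    payload=RequestInvokeSummary(
        text=long_text,
        instruction="Keep code identifiers and API names verbatim.",
    ),
)
# texts under ~60% of the model's context budget come back unchanged; longer
# texts are chunked, summarized per chunk, and re-summarized recursively
print(summary)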