diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 7dde4f0148..f785b2aed6 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -7,6 +7,7 @@ from controllers.inner_api import api from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.encrypt import PluginEncrypter @@ -47,11 +48,16 @@ class PluginInvokeTextEmbeddingApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): - return PluginModelBackwardsInvocation.invoke_text_embedding( - user_id=user_id, - tenant=tenant_model, - payload=payload, - ) + try: + return BaseBackwardsInvocationResponse( + data=PluginModelBackwardsInvocation.invoke_text_embedding( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeRerankApi(Resource): @@ -60,7 +66,16 @@ class PluginInvokeRerankApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeRerank) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank): - pass + try: + return BaseBackwardsInvocationResponse( + data=PluginModelBackwardsInvocation.invoke_rerank( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeTTSApi(Resource): @@ -69,7 +84,15 @@ class PluginInvokeTTSApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeTTS) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS): - pass + def generator(): + response = PluginModelBackwardsInvocation.invoke_tts( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) + return PluginModelBackwardsInvocation.convert_to_event_stream(response) + + return compact_generate_response(generator()) class PluginInvokeSpeech2TextApi(Resource): @@ -78,7 +101,16 @@ class PluginInvokeSpeech2TextApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): - pass + try: + return BaseBackwardsInvocationResponse( + data=PluginModelBackwardsInvocation.invoke_speech2text( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeModerationApi(Resource): @@ -87,7 +119,16 @@ class PluginInvokeModerationApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeModeration) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration): - pass + try: + return BaseBackwardsInvocationResponse( + data=PluginModelBackwardsInvocation.invoke_moderation( + user_id=user_id, + tenant=tenant_model, + payload=payload, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeToolApi(Resource): @@ -118,14 +159,19 @@ class PluginInvokeParameterExtractorNodeApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): - return PluginNodeBackwardsInvocation.invoke_parameter_extractor( - tenant_id=tenant_model.id, - user_id=user_id, - parameters=payload.parameters, - model_config=payload.model, - instruction=payload.instruction, - query=payload.query, - ) + try: + return BaseBackwardsInvocationResponse( + data=PluginNodeBackwardsInvocation.invoke_parameter_extractor( + tenant_id=tenant_model.id, + user_id=user_id, + parameters=payload.parameters, + model_config=payload.model, + instruction=payload.instruction, + query=payload.query, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeQuestionClassifierNodeApi(Resource): @@ -134,14 +180,19 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): @get_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): - return PluginNodeBackwardsInvocation.invoke_question_classifier( - tenant_id=tenant_model.id, - user_id=user_id, - query=payload.query, - model_config=payload.model, - classes=payload.classes, - instruction=payload.instruction, - ) + try: + return BaseBackwardsInvocationResponse( + data=PluginNodeBackwardsInvocation.invoke_question_classifier( + tenant_id=tenant_model.id, + user_id=user_id, + query=payload.query, + model_config=payload.model, + classes=payload.classes, + instruction=payload.instruction, + ) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() class PluginInvokeAppApi(Resource): @@ -173,7 +224,12 @@ class PluginInvokeEncryptApi(Resource): """ encrypt or decrypt data """ - return PluginEncrypter.invoke_encrypt(tenant_model, payload) + try: + return BaseBackwardsInvocationResponse( + data=PluginEncrypter.invoke_encrypt(tenant_model, payload) + ).model_dump() + except Exception as e: + return BaseBackwardsInvocationResponse(error=str(e)).model_dump() api.add_resource(PluginInvokeLLMApi, "/invoke/llm") diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py index 7b699b4d67..2ec71fdc5b 100644 --- a/api/core/plugin/backwards_invocation/base.py +++ b/api/core/plugin/backwards_invocation/base.py @@ -1,5 +1,6 @@ import json from collections.abc import Generator +from typing import Generic, Optional, TypeVar from pydantic import BaseModel @@ -8,15 +9,28 @@ class BaseBackwardsInvocation: @classmethod def convert_to_event_stream(cls, response: Generator[BaseModel | dict | str, None, None] | BaseModel | dict): if isinstance(response, Generator): - for chunk in response: - if isinstance(chunk, BaseModel): - yield chunk.model_dump_json().encode() + b'\n\n' - elif isinstance(chunk, str): - yield f"event: {chunk}\n\n".encode() - else: - yield json.dumps(chunk).encode() + b'\n\n' + try: + for chunk in response: + if isinstance(chunk, BaseModel): + yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n" + + elif isinstance(chunk, str): + yield f"event: {chunk}\n\n".encode() + else: + yield json.dumps(chunk).encode() + b"\n\n" + except Exception as e: + error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json() + yield f"{error_message}\n\n".encode() else: if isinstance(response, BaseModel): - yield response.model_dump_json().encode() + b'\n\n' + yield response.model_dump_json().encode() + b"\n\n" else: - yield json.dumps(response).encode() + b'\n\n' + yield json.dumps(response).encode() + b"\n\n" + + +T = TypeVar("T", bound=BaseModel | dict | str | bool | int) + + +class BaseBackwardsInvocationResponse(BaseModel, Generic[T]): + data: Optional[T] = None + error: str = "" diff --git a/api/core/plugin/encrypt/__init__.py b/api/core/plugin/encrypt/__init__.py index 313d161ec9..95c416d28c 100644 --- a/api/core/plugin/encrypt/__init__.py +++ b/api/core/plugin/encrypt/__init__.py @@ -8,7 +8,7 @@ from models.account import Tenant class PluginEncrypter: @classmethod - def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> Mapping[str, Any]: + def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: encrypter = ProviderConfigEncrypter( tenant_id=tenant.id, config=payload.data, @@ -16,16 +16,7 @@ class PluginEncrypter: provider_identity=payload.identity, ) - try: - if payload.opt == "encrypt": - return { - "data": encrypter.encrypt(payload.data), - } - else: - return { - "data": encrypter.decrypt(payload.data), - } - except Exception as e: - return { - "error": str(e), - } + if payload.opt == "encrypt": + return encrypter.encrypt(payload.data) + else: + return encrypter.decrypt(payload.data)