diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 486bf766bd..f81bcf7c0b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any from flask_restx import Resource from pydantic import BaseModel, Field @@ -11,12 +12,10 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.wraps import account_initialization_required, setup_required -from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db @@ -27,13 +26,28 @@ from services.workflow_service import WorkflowService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +class RuleGeneratePayload(BaseModel): + instruction: str = Field(..., description="Rule generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + no_variable: bool = Field(default=False, description="Whether to exclude variables") + + +class RuleCodeGeneratePayload(RuleGeneratePayload): + code_language: str = Field(default="javascript", description="Programming language for code generation") + + +class RuleStructuredOutputPayload(BaseModel): + instruction: str = Field(..., description="Structured output generation instruction") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") + + class InstructionGeneratePayload(BaseModel): flow_id: str = Field(..., description="Workflow/Flow ID") node_id: str = Field(default="", description="Node ID for workflow context") current: str = Field(default="", description="Current instruction text") language: str = Field(default="javascript", description="Programming language (javascript/python)") instruction: str = Field(..., description="Instruction for generation") - model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration") + model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration") ideal_output: str = Field(default="", description="Expected ideal output") @@ -50,7 +64,6 @@ reg(RuleCodeGeneratePayload) reg(RuleStructuredOutputPayload) reg(InstructionGeneratePayload) reg(InstructionTemplatePayload) -reg(ModelConfig) @console_ns.route("/rule-generate") @@ -65,28 +78,27 @@ class RuleGenerateApi(Resource): @login_required @account_initialization_required def post(self): - args = RuleGeneratePayload.model_validate(console_ns.payload) - account, current_tenant_id = current_account_with_tenant() + args = RuleGeneratePayload.model_validate(console_ns.payload) + account, current_tenant_id = current_account_with_tenant() - try: - rules = LLMGenerator.generate_rule_config( - tenant_id=current_tenant_id, - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=args.no_variable, - user_id=account.id, - app_id=None, - ) - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except InvokeError as e: - raise CompletionRequestError(e.description) + try: + rules = LLMGenerator.generate_rule_config( + tenant_id=current_tenant_id, + instruction=args.instruction, + model_config=args.model_config_data, + no_variable=args.no_variable, + user_id=account.id, + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) - return rules + return rules @console_ns.route("/rule-code-generate") @@ -111,7 +123,6 @@ class RuleCodeGenerateApi(Resource): model_config=args.model_config_data, code_language=args.code_language, user_id=account.id, - app_id=None, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -143,14 +154,9 @@ class RuleStructuredOutputGenerateApi(Resource): try: structured_output = LLMGenerator.generate_structured_output( tenant_id=current_tenant_id, -<<<<<<< HEAD - args=args, -======= instruction=args.instruction, model_config=args.model_config_data, user_id=account.id, - app_id=None, ->>>>>>> c56e5a5b71 (feat(telemetry): add prompt generation telemetry to Enterprise OTEL) ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -200,53 +206,29 @@ class InstructionGenerateApi(Resource): case "llm": return LLMGenerator.generate_rule_config( current_tenant_id, -<<<<<<< HEAD - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), -======= instruction=args.instruction, model_config=args.model_config_data, no_variable=True, user_id=account.id, app_id=args.flow_id, ->>>>>>> c56e5a5b71 (feat(telemetry): add prompt generation telemetry to Enterprise OTEL) ) case "agent": return LLMGenerator.generate_rule_config( current_tenant_id, -<<<<<<< HEAD - args=RuleGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - no_variable=True, - ), -======= instruction=args.instruction, model_config=args.model_config_data, no_variable=True, user_id=account.id, app_id=args.flow_id, ->>>>>>> c56e5a5b71 (feat(telemetry): add prompt generation telemetry to Enterprise OTEL) ) case "code": return LLMGenerator.generate_code( tenant_id=current_tenant_id, -<<<<<<< HEAD - args=RuleCodeGeneratePayload( - instruction=args.instruction, - model_config=args.model_config_data, - code_language=args.language, - ), -======= instruction=args.instruction, model_config=args.model_config_data, code_language=args.language, user_id=account.id, app_id=args.flow_id, ->>>>>>> c56e5a5b71 (feat(telemetry): add prompt generation telemetry to Enterprise OTEL) ) case _: return {"error": f"invalid node type: {node_type}"}