From 8340d775bd0806b34ddf9789c76998ce04b60b94 Mon Sep 17 00:00:00 2001
From: quicksand
Date: Fri, 25 Jul 2025 09:00:26 +0800
Subject: [PATCH] Improve: support custom model parameters in auto-generator
 (#22924)

---
 api/controllers/console/app/generator.py |  7 -------
 api/core/llm_generator/llm_generator.py  | 17 ++++-------------
 2 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 790369c052..4847a2cab8 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,5 +1,3 @@
-import os
-
 from flask_login import current_user
 from flask_restful import Resource, reqparse
 
@@ -29,15 +27,12 @@ class RuleGenerateApi(Resource):
         args = parser.parse_args()
 
         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
-
         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 no_variable=args["no_variable"],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -64,14 +59,12 @@ class RuleCodeGenerateApi(Resource):
         args = parser.parse_args()
 
         account = current_user
-        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
         try:
             code_result = LLMGenerator.generate_code(
                 tenant_id=account.current_tenant_id,
                 instruction=args["instruction"],
                 model_config=args["model_config"],
                 code_language=args["code_language"],
-                max_tokens=CODE_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 331ac933c8..80f0457962 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -125,16 +125,13 @@ class LLMGenerator:
         return questions
 
     @classmethod
-    def generate_rule_config(
-        cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512
-    ) -> dict:
+    def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
         output_parser = RuleConfigGeneratorOutputParser()
 
         error = ""
         error_step = ""
         rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
-        model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01}
-
+        model_parameters = model_config.get("completion_params", {})
         if no_variable:
             prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
 
@@ -276,12 +273,7 @@ class LLMGenerator:
 
     @classmethod
     def generate_code(
-        cls,
-        tenant_id: str,
-        instruction: str,
-        model_config: dict,
-        code_language: str = "javascript",
-        max_tokens: int = 1000,
+        cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
     ) -> dict:
         if code_language == "python":
             prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
@@ -305,8 +297,7 @@ class LLMGenerator:
         )
 
         prompt_messages = [UserPromptMessage(content=prompt)]
-        model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}
-
+        model_parameters = model_config.get("completion_params", {})
         try:
             response = cast(
                 LLMResult,
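
Note (illustrative sketch, not part of the patch): after this change, callers
supply generation parameters through model_config["completion_params"] instead
of the PROMPT_GENERATION_MAX_TOKENS / CODE_GENERATION_MAX_TOKENS environment
variables. A minimal Python sketch of the new calling convention; the
"provider"/"name" keys and the tenant id below are hypothetical placeholders,
and only "completion_params" is what the generator actually reads:

    from core.llm_generator.llm_generator import LLMGenerator

    rules = LLMGenerator.generate_rule_config(
        tenant_id="00000000-0000-0000-0000-000000000000",  # placeholder tenant id
        instruction="Generate a prompt for a customer-support assistant.",
        model_config={
            "provider": "openai",  # hypothetical provider/model selection
            "name": "gpt-4o",
            # Read by the generator via model_config.get("completion_params", {});
            # these values mirror the old hard-coded defaults (512 tokens, 0.01).
            "completion_params": {
                "max_tokens": 512,
                "temperature": 0.01,
            },
        },
        no_variable=True,
    )

When "completion_params" is absent, model_parameters falls back to an empty
dict, so the model provider's own defaults apply rather than the previous
env-var values.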