diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 152ff3b648..f5ba498c7d 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -23,6 +23,9 @@ jobs: uv run ruff check --fix-only . # Format code uv run ruff format . + - name: ast-grep + run: | + uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 diff --git a/.gitignore b/.gitignore index 5c68d89a4d..30432c4302 100644 --- a/.gitignore +++ b/.gitignore @@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info !.vscode/README.md pyrightconfig.json api/.vscode +# vscode Code History Extension +.history .idea/ diff --git a/api/.env.example b/api/.env.example index 3c30872422..3052dbfe2b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -478,6 +478,13 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node # API workflow run repository implementation API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=true +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 # App configuration APP_MAX_EXECUTION_TIME=1200 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 0b2f99aece..2bccc4b7a0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -968,6 +968,14 @@ class AccountConfig(BaseSettings): ) +class WorkflowLogConfig(BaseSettings): + WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup") + WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs") + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field( + default=100, description="Batch size for workflow run log cleanup operations" + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -1003,5 +1011,6 @@ class FeatureConfig( HostedServiceConfig, CeleryBeatConfig, CeleryScheduleTasksConfig, + WorkflowLogConfig, ): pass diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 008f1f0f7a..6a5197635e 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -1,3 +1,4 @@ +import contextlib import mimetypes import os import platform @@ -65,10 +66,8 @@ def guess_file_info_from_response(response: httpx.Response): # Use python-magic to guess MIME type if still unknown or generic if mimetype == "application/octet-stream" and magic is not None: - try: + with contextlib.suppress(magic.MagicException): mimetype = magic.from_buffer(response.content[:1024], mime=True) - except magic.MagicException: - pass extension = os.path.splitext(filename)[1] diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 493a9a52e2..2caa908d4a 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_login import current_user from flask_restful import Resource, marshal, marshal_with, reqparse @@ -24,7 +26,7 @@ class AnnotationReplyActionApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - def post(self, app_id, action): + def post(self, app_id, action: Literal["enable", "disable"]): if not current_user.is_editor: raise Forbidden() @@ -38,8 +40,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 4847a2cab8..57dc1267d5 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from flask_login import current_user from flask_restful import Resource, reqparse @@ -10,6 +12,8 @@ from controllers.console.app.error import ( ) from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -107,6 +111,121 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +class InstructionGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("flow_id", type=str, required=True, default="", location="json") + parser.add_argument("node_id", type=str, required=False, default="", location="json") + parser.add_argument("current", type=str, required=False, default="", location="json") + parser.add_argument("language", type=str, required=False, default="javascript", location="json") + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("ideal_output", type=str, required=False, default="", location="json") + args = parser.parse_args() + code_template = ( + Python3CodeProvider.get_default_code() + if args["language"] == "python" + else (JavascriptCodeProvider.get_default_code()) + if args["language"] == "javascript" + else "" + ) + try: + # Generate from nothing for a workflow node + if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": + from models import App, db + from services.workflow_service import WorkflowService + + app = db.session.query(App).where(App.id == args["flow_id"]).first() + if not app: + return {"error": f"app {args['flow_id']} not found"}, 400 + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + return {"error": f"workflow {args['flow_id']} not found"}, 400 + nodes: Sequence = workflow.graph_dict["nodes"] + node = [node for node in nodes if node["id"] == args["node_id"]] + if len(node) == 0: + return {"error": f"node {args['node_id']} not found"}, 400 + node_type = node[0]["data"]["type"] + match node_type: + case "llm": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "agent": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "code": + return LLMGenerator.generate_code( + tenant_id=current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + code_language=args["language"], + ) + case _: + return {"error": f"invalid node type: {node_type}"} + if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow + return LLMGenerator.instruction_modify_legacy( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + if args["node_id"] != "" and args["current"] != "": # For workflow node + return LLMGenerator.instruction_modify_workflow( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + node_id=args["node_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + return {"error": "incompatible parameters"}, 400 + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + +class InstructionGenerationTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self) -> dict: + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, default=False, location="json") + args = parser.parse_args() + match args["type"]: + case "prompt": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT} + case "code": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} + case _: + raise ValueError(f"Invalid type: {args['type']}") + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") +api.add_resource(InstructionGenerateApi, "/instruction-generate") +api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 36a5444825..d4e67409fc 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,7 +1,7 @@ import json import logging from argparse import ArgumentTypeError -from typing import cast +from typing import Literal, cast from flask import request from flask_login import current_user @@ -761,7 +761,7 @@ class DocumentProcessingApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) @@ -787,8 +787,6 @@ class DocumentProcessingApi(DocumentResource): document.paused_at = None document.is_paused = False db.session.commit() - else: - raise InvalidActionError() return {"result": "success"}, 200 @@ -843,7 +841,7 @@ class DocumentStatusApi(DocumentResource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, action): + def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 65f76fb402..1b5570285d 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound @@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def post(self, dataset_id, action): + def post(self, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/upload_file.py b/api/controllers/console/datasets/upload_file.py index 9b456c771d..2afdaf7f2b 100644 --- a/api/controllers/console/datasets/upload_file.py +++ b/api/controllers/console/datasets/upload_file.py @@ -39,7 +39,7 @@ class UploadFileApi(Resource): data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("UploadFile not found.") else: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 9b22c535f4..23446bb702 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden @@ -15,7 +17,7 @@ from services.annotation_service import AppAnnotationService class AnnotationReplyActionApi(Resource): @validate_app_token - def post(self, app_model: App, action): + def post(self, app_model: App, action: Literal["enable", "disable"]): parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json") @@ -25,8 +27,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 29eef41253..35b1efeff6 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask import request from flask_restful import marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -358,14 +360,14 @@ class DatasetApi(DatasetApiResource): class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" - def patch(self, tenant_id, dataset_id, action): + def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. Args: tenant_id: tenant id dataset_id: dataset id - action: action to perform (enable, disable, archive, un_archive) + action: action to perform (Literal["enable", "disable", "archive", "un_archive"]) Returns: dict: A dictionary with a key 'result' and a value 'success' diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 6ba818c5fc..75a0b18285 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,3 +1,5 @@ +from typing import Literal + from flask_login import current_user # type: ignore from flask_restful import marshal, reqparse from werkzeug.exceptions import NotFound @@ -77,7 +79,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, action): + def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index f3b9dbf758..0d786ba051 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -181,7 +181,7 @@ class MessageCycleManager: :param message_id: message id :return: """ - message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first() + message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first() event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE return MessageStreamResponse( diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 47e5a79160..503f8e3e89 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,6 +1,7 @@ import json import logging import re +from collections.abc import Sequence from typing import Optional, cast import json_repair @@ -11,6 +12,8 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, + LLM_MODIFY_CODE_SYSTEM, + LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.graph_engine.entities.event import AgentLogEvent +from models import App, Message, WorkflowNodeExecutionModel, db class LLMGenerator: @@ -388,3 +394,181 @@ class LLMGenerator: except Exception as e: logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + + @staticmethod + def instruction_modify_legacy( + tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + ) -> dict: + app: App | None = db.session.query(App).where(App.id == flow_id).first() + last_run: Message | None = ( + db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + ) + if not last_run: + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + last_run_dict = { + "query": last_run.query, + "answer": last_run.answer, + "error": last_run.error, + } + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=str(last_run.error), + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + + @staticmethod + def instruction_modify_workflow( + tenant_id: str, + flow_id: str, + node_id: str, + current: str, + instruction: str, + model_config: dict, + ideal_output: str | None, + ) -> dict: + from services.workflow_service import WorkflowService + + app: App | None = db.session.query(App).where(App.id == flow_id).first() + if not app: + raise ValueError("App not found.") + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + raise ValueError("Workflow not found for the given app model.") + last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) + try: + node_type = cast(WorkflowNodeExecutionModel, last_run).node_type + except Exception: + try: + node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][ + "type" + ] + except Exception: + node_type = "llm" + + if not last_run: # Node is not executed yet + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type=node_type, + ideal_output=ideal_output, + ) + + def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: + raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) + if not raw_agent_log: + return [] + parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) + + def dict_of_event(event: AgentLogEvent) -> dict: + return { + "status": event.status, + "error": event.error, + "data": event.data, + } + + return [dict_of_event(event) for event in parsed] + + last_run_dict = { + "inputs": last_run.inputs_dict, + "status": last_run.status, + "error": last_run.error, + "agent_log": agent_log_of(last_run), + } + + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=last_run.error, + instruction=instruction, + node_type=last_run.node_type, + ideal_output=ideal_output, + ) + + @staticmethod + def __instruction_modify_common( + tenant_id: str, + model_config: dict, + last_run: dict | None, + current: str | None, + error_message: str | None, + instruction: str, + node_type: str, + ideal_output: str | None, + ) -> dict: + LAST_RUN = "{{#last_run#}}" + CURRENT = "{{#current#}}" + ERROR_MESSAGE = "{{#error_message#}}" + injected_instruction = instruction + if LAST_RUN in injected_instruction: + injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run)) + if CURRENT in injected_instruction: + injected_instruction = injected_instruction.replace(CURRENT, current or "null") + if ERROR_MESSAGE in injected_instruction: + injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + match node_type: + case "llm", "agent": + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + case "code": + system_prompt = LLM_MODIFY_CODE_SYSTEM + case _: + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + prompt_messages = [ + SystemPromptMessage(content=system_prompt), + UserPromptMessage( + content=json.dumps( + { + "current": current, + "last_run": last_run, + "instruction": injected_instruction, + "ideal_output": ideal_output, + } + ) + ), + ] + model_parameters = {"temperature": 0.4} + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_raw = cast(str, response.message.content) + first_brace = generated_raw.find("{") + last_brace = generated_raw.rfind("}") + return {**json.loads(generated_raw[first_brace : last_brace + 1])} + + except InvokeError as e: + error = str(e) + return {"error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) + return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ef81e38dc5..9268347526 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -309,3 +309,116 @@ eg: Here is the JSON schema: {{schema}} """ # noqa: E501 + +LLM_MODIFY_PROMPT_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.", + "properties": { + "current": { + "type": "string", + "description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed." + }, + "last_run": { + "type": "object", + "description": "The output result from the large language model after receiving the prompt.", + }, + "instruction": { + "type": "string", + "description": "User's instruction to edit the current prompt" + }, + "ideal_output": { + "type": "string", + "description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal." + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +LLM_MODIFY_CODE_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.", + "properties": { + "current": { + "type": "string", + "description": "The code before modification." + }, + "last_run": { + "type": "object", + "description": "The result of the code.", + }, + "message": { + "type": "string", + "description": "User's instruction to edit the current code" + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +When you are modifying the code, you should remember: +- Do not use print, this not work in dify sandbox. +- Do not try dangerous call like deleting files. It's PROHIBITED. +- Do not use any library that is not built-in in with Python. +- Get inputs from the parameters of the function and have explicit type annotations. +- Write proper imports at the top of the code. +- Use return statement to return the result. +- You should return a `dict`. If you need to return a `result: str`, you should `return {"result": result}`. +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}. +You should edit the prompt according to the IDEAL OUTPUT.""" + +INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index d6dfe967d7..8efe105bbf 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,8 +4,8 @@ import math from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient # type: ignore -from sqlalchemy import JSON, Column, String, func +from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore +from sqlalchemy import JSON, Column, String from sqlalchemy.dialects.mysql import LONGTEXT from configs import dify_config @@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector): ) try: if self._hybrid_search_enabled: - self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name} - ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""") + self._client.create_fts_idx_with_fts_index_param( + table_name=self._collection_name, + fts_idx_param=FtsIndexParam( + index_name="fulltext_index_for_col_text", + field_names=["text"], + parser_type=FtsParser.IK, + ), + ) except Exception as e: raise Exception( "Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above " + "to support fulltext index and vector index in the same table", e, ) + self._client.refresh_metadata([self._collection_name]) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _check_hybrid_search_support(self) -> bool: @@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector): vec_column_name="vector", vec_data=query_vector, topk=topk, - distance_func=func.l2_distance, + distance_func=l2_distance, output_column_names=["text", "metadata"], with_dist=True, where_clause=_where_clause, diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 9741dd8b1d..fcf3a6d126 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -331,6 +331,12 @@ class QdrantVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if score_threshold >= 1: + # return empty list because some versions of qdrant may response with 400 bad request, + # and at the same time, the score_threshold with value 1 may be valid for other vector stores + return [] + filter = models.Filter( must=[ models.FieldCondition( @@ -355,7 +361,7 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=float(kwargs.get("score_threshold") or 0.0), + score_threshold=score_threshold, ) docs = [] for result in results: @@ -363,7 +369,6 @@ class QdrantVector(BaseVector): continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 198f60e554..00e0bd9a16 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -145,13 +145,19 @@ def init_app(app: DifyApp) -> Celery: minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 ), } - if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: + if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") beat_schedule["check_upgradable_plugin_task"] = { "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", "schedule": crontab(minute="*/15"), } - + if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED: + # 2:00 AM every day + imports.append("schedule.clean_workflow_runlogs_precise") + beat_schedule["clean_workflow_runlogs_precise"] = { + "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", + "schedule": crontab(minute="0", hour="2"), + } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/pyproject.toml b/api/pyproject.toml index 61a725a830..ce642aa9c8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -205,7 +205,7 @@ vdb = [ "pgvector==0.2.5", "pymilvus~=2.5.0", "pymochow==1.3.1", - "pyobvector~=0.1.6", + "pyobvector~=0.2.15", "qdrant-client==1.9.0", "tablestore==6.2.0", "tcvectordb~=1.6.4", diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py new file mode 100644 index 0000000000..8c21be01dc --- /dev/null +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -0,0 +1,155 @@ +import datetime +import logging +import time + +import click + +import app +from configs import dify_config +from extensions.ext_database import db +from models.model import ( + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.workflow import ConversationVariable, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun + +_logger = logging.getLogger(__name__) + + +MAX_RETRIES = 3 +BATCH_SIZE = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE + + +@app.celery.task(queue="dataset") +def clean_workflow_runlogs_precise(): + """Clean expired workflow run logs with retry mechanism and complete message cascade""" + + click.echo(click.style("Start clean workflow run logs (precise mode with complete cascade).", fg="green")) + start_at = time.perf_counter() + + retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS + cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) + + try: + total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count() + if total_workflow_runs == 0: + _logger.info("No expired workflow run logs found") + return + _logger.info("Found %s expired workflow run logs to clean", total_workflow_runs) + + total_deleted = 0 + failed_batches = 0 + batch_count = 0 + + while True: + workflow_runs = ( + db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() + ) + + if not workflow_runs: + break + + workflow_run_ids = [run.id for run in workflow_runs] + batch_count += 1 + + success = _delete_batch_with_retry(workflow_run_ids, failed_batches) + + if success: + total_deleted += len(workflow_run_ids) + failed_batches = 0 + else: + failed_batches += 1 + if failed_batches >= MAX_RETRIES: + _logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES) + break + else: + # Calculate incremental delay times: 5, 10, 15 minutes + retry_delay_minutes = failed_batches * 5 + _logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes) + time.sleep(retry_delay_minutes * 60) + continue + + _logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted) + + except Exception as e: + db.session.rollback() + _logger.exception("Unexpected error in workflow log cleanup") + raise + + end_at = time.perf_counter() + execution_time = end_at - start_at + click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green")) + + +def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool: + """Delete a single batch with a retry mechanism and complete cascading deletion""" + try: + with db.session.begin_nested(): + message_data = ( + db.session.query(Message.id, Message.conversation_id) + .filter(Message.workflow_run_id.in_(workflow_run_ids)) + .all() + ) + message_id_list = [msg.id for msg in message_data] + conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) + if message_id_list: + db.session.query(AppAnnotationHitHistory).where( + AppAnnotationHitHistory.message_id.in_(message_id_list) + ).delete(synchronize_session=False) + + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) + ).delete(synchronize_session=False) + + if conversation_id_list: + db.session.query(ConversationVariable).where( + ConversationVariable.conversation_id.in_(conversation_id_list) + ).delete(synchronize_session=False) + + db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) + + db.session.commit() + return True + + except Exception as e: + db.session.rollback() + _logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1) + return False diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index b7a047914e..1a0fdfa420 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -293,7 +293,7 @@ class AppAnnotationService: annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] # Step 2: Bulk delete hit histories in a single query - db.session.query(AppAnnotationHitHistory).filter( + db.session.query(AppAnnotationHitHistory).where( AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) ).delete(synchronize_session=False) @@ -307,7 +307,7 @@ class AppAnnotationService: # Step 4: Bulk delete annotations in a single query deleted_count = ( db.session.query(MessageAnnotation) - .filter(MessageAnnotation.id.in_(annotation_ids_to_delete)) + .where(MessageAnnotation.id.in_(annotation_ids_to_delete)) .delete(synchronize_session=False) ) @@ -505,9 +505,9 @@ class AppAnnotationService: db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) - annotations_query = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id) + annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id) for annotation in annotations_query.yield_per(100): - annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).filter( + annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where( AppAnnotationHitHistory.annotation_id == annotation.id ) for annotation_hit_history in annotation_hit_histories_query.yield_per(100): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 089cbdbe9e..77c17f67fb 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import secrets import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Literal, Optional from flask_login import current_user from sqlalchemy import func, select @@ -55,7 +55,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, RagPipelineDatasetCreateEntity, ) -from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError @@ -2231,14 +2231,16 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @staticmethod - def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + def batch_update_document_status( + dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user + ): """ Batch update document status. Args: dataset (Dataset): The dataset object document_ids (list[str]): List of document IDs to update - action (str): Action to perform (enable, disable, archive, un_archive) + action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform user: Current user performing the action Raises: @@ -2321,9 +2323,10 @@ class DocumentService: raise propagation_error @staticmethod - def _prepare_document_status_update(document, action: str, user): - """ - Prepare document status update information. + def _prepare_document_status_update( + document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user + ): + """Prepare document status update information. Args: document: Document object to update @@ -2786,7 +2789,9 @@ class SegmentService: db.session.commit() @classmethod - def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + def update_segments_status( + cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document + ): # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return @@ -2844,8 +2849,6 @@ class SegmentService: db.session.commit() disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - else: - raise InvalidActionError() @classmethod def create_child_chunk( diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 8c4c1876ad..5ab377c232 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,5 +1,6 @@ import logging import time +from typing import Literal import click from celery import shared_task # type: ignore @@ -13,7 +14,7 @@ from models.dataset import Document as DatasetDocument @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: str): +def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): """ Async deal dataset from index :param dataset_id: dataset_id diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 61d9a9e712..fe0e03f7b8 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,4 +1,5 @@ from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector +from core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, setup_mock_redis, @@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest): ), ) + def search_by_vector(self): + super().search_by_vector() + # only test for qdrant, may not work on other vector stores + hits_by_vector: list[Document] = self.vector.search_by_vector( + query_vector=self.example_embedding, score_threshold=1 + ) + assert len(hits_by_vector) == 0 + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests() diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 0ab5f398e3..8816698af8 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -471,7 +471,7 @@ class TestAnnotationService: # Verify annotation was deleted from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) @@ -1175,7 +1175,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() assert deleted_annotation is None # Verify delete_annotation_index_task was called diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 38f532fd64..6cd8337ff9 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -234,7 +234,7 @@ class TestAPIBasedExtensionService: # Verify extension was deleted from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).filter(APIBasedExtension.id == extension_id).first() + deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() assert deleted_extension is None def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py new file mode 100644 index 0000000000..8bd5440411 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -0,0 +1,1785 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel + + +class TestFeatureService: + """Integration tests for FeatureService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.feature_service.BillingService") as mock_billing_service, + patch("services.feature_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns for BillingService + mock_billing_service.get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly", "education": True}, + "members": {"size": 5, "limit": 10}, + "apps": {"size": 3, "limit": 20}, + "vector_space": {"size": 2, "limit": 10}, + "documents_upload_quota": {"size": 15, "limit": 100}, + "annotation_quota_limit": {"size": 8, "limit": 50}, + "docs_processing": "enhanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + "knowledge_rate_limit": {"limit": 100}, + } + + mock_billing_service.get_knowledge_rate_limit.return_value = {"limit": 100, "subscription_plan": "pro"} + + # Setup default mock returns for EnterpriseService + mock_enterprise_service.get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 5, "limit": 10, "enabled": True} + } + + mock_enterprise_service.get_info.return_value = { + "SSOEnforcedForSignin": True, + "SSOEnforcedForSigninProtocol": "saml", + "EnableEmailCodeLogin": True, + "EnableEmailPasswordLogin": False, + "IsAllowRegister": False, + "IsAllowCreateWorkspace": False, + "Branding": { + "applicationTitle": "Test Enterprise", + "loginPageLogo": "https://example.com/logo.png", + "workspaceLogo": "https://example.com/workspace.png", + "favicon": "https://example.com/favicon.ico", + }, + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False}, + "SSOEnforcedForWebProtocol": "oidc", + "License": { + "status": "active", + "expiredAt": "2025-12-31", + "workspaces": {"enabled": True, "limit": 5, "used": 2}, + }, + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + }, + } + + yield { + "billing_service": mock_billing_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_tenant_id(self): + """Helper method to create a test tenant ID.""" + fake = Faker() + return fake.uuid4() + + def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful feature retrieval with billing and enterprise enabled. + + This test verifies: + - Proper feature model creation with all required fields + - Correct integration with billing service + - Proper enterprise workspace information handling + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config mocking + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "pro" + assert result.billing.subscription.interval == "monthly" + assert result.education.activated is True + + # Verify member limitations + assert result.members.size == 5 + assert result.members.limit == 10 + + # Verify app limitations + assert result.apps.size == 3 + assert result.apps.limit == 20 + + # Verify vector space limitations + assert result.vector_space.size == 2 + assert result.vector_space.limit == 10 + + # Verify document upload quota + assert result.documents_upload_quota.size == 15 + assert result.documents_upload_quota.limit == 100 + + # Verify annotation quota + assert result.annotation_quota_limit.size == 8 + assert result.annotation_quota_limit.limit == 50 + + # Verify other features + assert result.docs_processing == "enhanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.knowledge_rate_limit == 100 + + # Verify enterprise features + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify webapp copyright is enabled for non-sandbox plans + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval for sandbox plan with specific limitations. + + This test verifies: + - Proper handling of sandbox plan limitations + - Correct webapp copyright settings for sandbox + - Transfer workspace restrictions for sandbox plans + - Proper billing service integration + """ + # Arrange: Setup sandbox plan mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + # Set mock return value inside the patch context + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "sandbox", "interval": "monthly", "education": False}, + "members": {"size": 1, "limit": 3}, + "apps": {"size": 1, "limit": 5}, + "vector_space": {"size": 1, "limit": 2}, + "documents_upload_quota": {"size": 5, "limit": 20}, + "annotation_quota_limit": {"size": 2, "limit": 10}, + "docs_processing": "standard", + "can_replace_logo": False, + "model_load_balancing_enabled": False, + "knowledge_rate_limit": {"limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify sandbox-specific limitations + assert result.billing.subscription.plan == "sandbox" + assert result.education.activated is False + + # Verify sandbox limitations + assert result.members.size == 1 + assert result.members.limit == 3 + assert result.apps.size == 1 + assert result.apps.limit == 5 + assert result.vector_space.size == 1 + assert result.vector_space.limit == 2 + assert result.documents_upload_quota.size == 5 + assert result.documents_upload_quota.limit == 20 + assert result.annotation_quota_limit.size == 2 + assert result.annotation_quota_limit.limit == 10 + + # Verify sandbox-specific restrictions + assert result.webapp_copyright_enabled is False + assert result.is_allow_transfer_workspace is False + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.docs_processing == "standard" + assert result.knowledge_rate_limit == 10 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful knowledge rate limit retrieval with billing enabled. + + This test verifies: + - Proper knowledge rate limit model creation + - Correct integration with billing service + - Proper rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify rate limit configuration + assert result.enabled is True + assert result.limit == 100 + assert result.subscription_plan == "pro" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system features retrieval with enterprise and marketplace enabled. + + This test verifies: + - Proper system feature model creation + - Correct integration with enterprise service + - Proper marketplace configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # Verify authentication settings + assert result.enable_email_code_login is True + assert result.enable_email_password_login is False + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify branding configuration + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + assert result.branding.workspace_logo == "https://example.com/workspace.png" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "oidc" + + # Verify license configuration + assert result.license.status.value == "active" + assert result.license.expired_at == "2025-12-31" + assert result.license.workspaces.enabled is True + assert result.license.workspaces.limit == 5 + assert result.license.workspaces.size == 2 + + # Verify plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with basic configuration (no enterprise). + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup basic config mock (no enterprise) + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = True + mock_config.ALLOW_CREATE_WORKSPACE = True + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify basic configuration + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is True + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is False + assert result.is_allow_register is True + assert result.is_allow_create_workspace is True + assert result.is_email_setup is True + + # Verify marketplace configuration + assert result.enable_marketplace is False + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval when billing is disabled. + + This test verifies: + - Proper feature model creation without billing + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.dataset_operator_enabled is True + assert result.education.enabled is True + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify no enterprise features + assert result.workspace_members.enabled is False + assert result.webapp_copyright_enabled is False + + def test_get_knowledge_rate_limit_billing_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test knowledge rate limit retrieval when billing is disabled. + + This test verifies: + - Proper knowledge rate limit model creation without billing + - Default rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify default configuration + assert result.enabled is False + assert result.limit == 10 + assert result.subscription_plan == "" # Empty string when billing is disabled + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() + + def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with enterprise enabled but billing disabled. + + This test verifies: + - Proper feature model creation with enterprise only + - Correct enterprise service integration + - Proper workspace member handling + - Return value correctness and structure + """ + # Arrange: Setup enterprise only mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify workspace members from enterprise + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify environment-based features + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is False + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_system_features_enterprise_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval when enterprise is disabled. + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup enterprise disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = True + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = None + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 50 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features are disabled + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + assert result.is_email_setup is False + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + # Verify default license status + assert result.license.status.value == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify no enterprise service calls + mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() + + def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval without tenant ID (billing disabled). + + This test verifies: + - Proper feature model creation without tenant ID + - Correct handling when billing is disabled + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup no tenant ID scenario + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + # Act: Execute the method under test + result = FeatureService.get_features("") + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled due to no tenant ID + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is True + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with partial billing information. + + This test verifies: + - Proper handling of partial billing data + - Correct fallback to default values + - Proper billing service integration + - Return value correctness and structure + """ + # Arrange: Setup partial billing info mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + # Missing members, apps, vector_space, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "basic" + assert result.billing.subscription.interval == "yearly" + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify basic plan restrictions (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case vector space configuration. + + This test verifies: + - Proper handling of vector space quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case vector space mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "vector_space": {"size": 0, "limit": 0}, + "apps": {"size": 5, "limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify vector space configuration + assert result.vector_space.size == 0 + assert result.vector_space.limit == 0 + + # Verify apps configuration + assert result.apps.size == 5 + assert result.apps.limit == 10 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_webapp_auth( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case webapp auth configuration. + + This test verifies: + - Proper handling of webapp auth configuration + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case webapp auth mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "WebAppAuth": {"allowSso": False, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case members quota configuration. + + This test verifies: + - Proper handling of members quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case members quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + "members": {"size": 10, "limit": 10}, + "vector_space": {"size": 3, "limit": 5}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify members configuration + assert result.members.size == 10 + assert result.members.limit == 10 + + # Verify vector space configuration + assert result.vector_space.size == 3 + assert result.vector_space.limit == 5 + + # Verify basic plan features (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_plugin_installation_permission_scopes( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with different plugin installation permission scopes. + + This test verifies: + - Proper handling of different plugin installation scopes + - Correct enterprise service integration + - Proper permission configuration + - Return value correctness and structure + """ + + # Test case 1: Official only scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Test case 2: All plugins scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "all", "restrictToMarketplaceOnly": False} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 3: Specific partners scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_and_specific_partners", + "restrictToMarketplaceOnly": False, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_and_specific_partners" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 4: None scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "none", "restrictToMarketplaceOnly": True} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "none" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + def test_get_features_workspace_members_missing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members info is missing from enterprise. + + This test verifies: + - Proper handling of missing workspace members data + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup missing workspace members mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + # Missing WorkspaceMembers key + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members use default values + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with inactive license. + + This test verifies: + - Proper handling of inactive license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup inactive license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "inactive", + "expiredAt": "", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "inactive" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_partial_enterprise_info( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with partial enterprise information. + + This test verifies: + - Proper handling of partial enterprise data + - Correct fallback to default values + - Proper enterprise service integration + - Return value correctness and structure + """ + # Arrange: Setup partial enterprise info mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSignin": True, + "Branding": {"applicationTitle": "Partial Enterprise"}, + # Missing WebAppAuth, License, PluginInstallationPermission, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "" + + # Verify branding configuration (partial) + assert result.branding.application_title == "Partial Enterprise" + assert result.branding.login_page_logo == "" + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "" + + # Verify default values for missing enterprise info + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify default license status + assert result.license.status == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify default plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case limit values. + + This test verifies: + - Proper handling of zero and negative limits + - Correct handling of very large limits + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case limits mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "members": {"size": 0, "limit": 0}, + "apps": {"size": 0, "limit": -1}, + "vector_space": {"size": 0, "limit": 999999}, + "documents_upload_quota": {"size": 0, "limit": 0}, + "annotation_quota_limit": {"size": 0, "limit": 1}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify edge case limits + assert result.members.size == 0 + assert result.members.limit == 0 + assert result.apps.size == 0 + assert result.apps.limit == -1 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 999999 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 1 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_protocols( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case protocol values. + + This test verifies: + - Proper handling of empty protocol strings + - Correct handling of special protocol values + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case protocols mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSigninProtocol": "", + "SSOEnforcedForWebProtocol": " ", + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": False, "allowEmailPasswordLogin": True}, + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify edge case protocols + assert result.sso_enforced_for_signin_protocol == "" + assert result.webapp_auth.sso_config.protocol == " " + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is True + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case education configuration. + + This test verifies: + - Proper handling of education feature flags + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case education mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "education", "interval": "semester", "education": True}, + "members": {"size": 100, "limit": 200}, + "apps": {"size": 50, "limit": 100}, + "vector_space": {"size": 20, "limit": 50}, + "documents_upload_quota": {"size": 500, "limit": 1000}, + "annotation_quota_limit": {"size": 200, "limit": 500}, + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education features + assert result.education.enabled is True + assert result.education.activated is True + + # Verify education plan limits + assert result.members.size == 100 + assert result.members.limit == 200 + assert result.apps.size == 50 + assert result.apps.limit == 100 + assert result.vector_space.size == 20 + assert result.vector_space.limit == 50 + assert result.documents_upload_quota.size == 500 + assert result.documents_upload_quota.limit == 1000 + assert result.annotation_quota_limit.size == 200 + assert result.annotation_quota_limit.limit == 500 + + # Verify education plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_license_limitation_model_is_available( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test LicenseLimitationModel.is_available method with various scenarios. + + This test verifies: + - Proper quota availability calculation + - Correct handling of unlimited limits + - Proper handling of disabled limits + - Return value correctness for different scenarios + """ + from services.feature_service import LicenseLimitationModel + + # Test case 1: Limit disabled + disabled_limit = LicenseLimitationModel(enabled=False, size=5, limit=10) + assert disabled_limit.is_available(3) is True + assert disabled_limit.is_available(10) is True + + # Test case 2: Unlimited limit + unlimited_limit = LicenseLimitationModel(enabled=True, size=5, limit=0) + assert unlimited_limit.is_available(3) is True + assert unlimited_limit.is_available(100) is True + + # Test case 3: Available quota + available_limit = LicenseLimitationModel(enabled=True, size=5, limit=10) + assert available_limit.is_available(3) is True + assert available_limit.is_available(5) is True + assert available_limit.is_available(1) is True + + # Test case 4: Insufficient quota + insufficient_limit = LicenseLimitationModel(enabled=True, size=8, limit=10) + assert insufficient_limit.is_available(3) is False + assert insufficient_limit.is_available(2) is True + assert insufficient_limit.is_available(1) is True + + # Test case 5: Exact quota usage + exact_limit = LicenseLimitationModel(enabled=True, size=7, limit=10) + assert exact_limit.is_available(3) is True + assert exact_limit.is_available(3) is True + + def test_get_features_workspace_members_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members are disabled in enterprise. + + This test verifies: + - Proper handling of disabled workspace members + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup workspace members disabled mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 0, "limit": 0, "enabled": False} + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members are disabled + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) + + def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with expired license. + + This test verifies: + - Proper handling of expired license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup expired license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "expired", + "expiredAt": "2023-12-31", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "expired" + assert result.license.expired_at == "2023-12-31" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_docs_processing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case document processing configuration. + + This test verifies: + - Proper handling of different document processing modes + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case docs processing mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "premium", "interval": "monthly"}, + "docs_processing": "advanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify docs processing configuration + assert result.docs_processing == "advanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + + # Verify premium plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default limitations (no specific billing info) + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_branding( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case branding configuration. + + This test verifies: + - Proper handling of partial branding information + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case branding mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "Branding": { + "applicationTitle": "Edge Case App", + "loginPageLogo": None, + "workspaceLogo": "", + "favicon": "https://example.com/favicon.ico", + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify branding configuration (edge cases) + assert result.branding.application_title == "Edge Case App" + assert result.branding.login_page_logo is None # None value from mock + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_annotation_quota( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case annotation quota configuration. + + This test verifies: + - Proper handling of annotation quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case annotation quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "annotation_quota_limit": {"size": 999, "limit": 1000}, + "knowledge_rate_limit": {"limit": 500}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify annotation quota configuration + assert result.annotation_quota_limit.size == 999 + assert result.annotation_quota_limit.limit == 1000 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 500 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_documents_upload( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case documents upload settings. + + This test verifies: + - Proper handling of edge case documents upload configuration + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case documents upload mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "documents_upload_quota": { + "size": 0, # Edge case: zero current size + "limit": 0, # Edge case: zero limit + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify documents upload quota configuration (edge cases) + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_license_lost( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features with lost license status. + + This test verifies: + - Proper handling of lost license status + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup lost license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "license": {"status": "lost", "expired_at": None, "plan": None} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with education feature disabled. + + This test verifies: + - Proper handling of disabled education features + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup education disabled mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": { + "plan": "pro", + "interval": "monthly", + "education": False, # Education explicitly disabled + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education configuration + assert result.education.activated is False + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 25ba0d03ef..ece6de6cdf 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -484,7 +484,7 @@ class TestMessageService: # Verify feedback was deleted from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).filter(MessageFeedback.id == feedback.id).first() + deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index a8a36b2565..cb20238f0c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -469,6 +469,6 @@ class TestModelLoadBalancingService: # Verify inherit config was created in database inherit_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.name == "__inherit__").all() + db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() ) assert len(inherit_configs) == 1 diff --git a/api/uv.lock b/api/uv.lock index cecce2bc43..40091c297a 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1602,7 +1602,7 @@ vdb = [ { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, { name = "pymochow", specifier = "==1.3.1" }, - { name = "pyobvector", specifier = "~=0.1.6" }, + { name = "pyobvector", specifier = "~=0.2.15" }, { name = "qdrant-client", specifier = "==1.9.0" }, { name = "tablestore", specifier = "==6.2.0" }, { name = "tcvectordb", specifier = "~=1.6.4" }, @@ -4569,17 +4569,19 @@ wheels = [ [[package]] name = "pyobvector" -version = "0.1.14" +version = "0.2.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiomysql" }, { name = "numpy" }, + { name = "pydantic" }, { name = "pymysql" }, { name = "sqlalchemy" }, + { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/59/7d762061808948dd6aad165a000b34e22163dc83fb5014184eeacc0fabe5/pyobvector-0.1.14.tar.gz", hash = "sha256:4f85cdd63064d040e94c0a96099a0cd5cda18ce625865382e89429f28422fc02", size = 26780, upload-time = "2024-11-20T11:46:18.017Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/7d/3f3aac6acf1fdd1782042d6eecd48efaa2ee355af0dbb61e93292d629391/pyobvector-0.2.15.tar.gz", hash = "sha256:5de258c1e952c88b385b5661e130c1cf8262c498c1f8a4a348a35962d379fce4", size = 39611, upload-time = "2025-08-18T02:49:26.683Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/68/ecb21b74c974e7be7f9034e205d08db62d614ff5c221581ae96d37ef853e/pyobvector-0.1.14-py3-none-any.whl", hash = "sha256:828e0bec49a177355b70c7a1270af3b0bf5239200ee0d096e4165b267eeff97c", size = 35526, upload-time = "2024-11-20T11:46:16.809Z" }, + { url = "https://files.pythonhosted.org/packages/5f/1f/a62754ba9b8a02c038d2a96cb641b71d3809f34d2ba4f921fecd7840d7fb/pyobvector-0.2.15-py3-none-any.whl", hash = "sha256:feeefe849ee5400e72a9a4d3844e425a58a99053dd02abe06884206923065ebb", size = 52680, upload-time = "2025-08-18T02:49:25.452Z" }, ] [[package]] @@ -5432,6 +5434,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, ] +[[package]] +name = "sqlglot" +version = "26.33.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/25/9d/fcd59b4612d5ad1e2257c67c478107f073b19e1097d3bfde2fb517884416/sqlglot-26.33.0.tar.gz", hash = "sha256:2817278779fa51d6def43aa0d70690b93a25c83eb18ec97130fdaf707abc0d73", size = 5353340, upload-time = "2025-07-01T13:09:06.311Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/8d/f1d9cb5b18e06aa45689fbeaaea6ebab66d5f01d1e65029a8f7657c06be5/sqlglot-26.33.0-py3-none-any.whl", hash = "sha256:031cee20c0c796a83d26d079a47fdce667604df430598c7eabfa4e4dfd147033", size = 477610, upload-time = "2025-07-01T13:09:03.926Z" }, +] + [[package]] name = "sseclient-py" version = "1.8.0" diff --git a/docker/.env.example b/docker/.env.example index 743a1e8bba..826b7b9fe6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -887,6 +887,14 @@ API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository. # API workflow node execution repository implementation API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=false +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index bcf9588dff..0c352e4658 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -396,6 +396,9 @@ x-shared-env: &shared-api-worker-env CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} + WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} + WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 41a6971bf5..88e3a7b343 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -4,7 +4,7 @@ import type { Area } from 'react-easy-crop' import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { RiPencilLine } from '@remixicon/react' +import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import { updateUserProfile } from '@/service/common' import { ToastContext } from '@/app/components/base/toast' import ImageInput, { type OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' @@ -27,6 +27,8 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const [inputImageInfo, setInputImageInfo] = useState() const [isShowAvatarPicker, setIsShowAvatarPicker] = useState(false) const [uploading, setUploading] = useState(false) + const [isShowDeleteConfirm, setIsShowDeleteConfirm] = useState(false) + const [hoverArea, setHoverArea] = useState('left') const handleImageInput: OnImageInput = useCallback(async (isCropped: boolean, fileOrTempUrl: string | File, croppedAreaPixels?: Area, fileName?: string) => { setInputImageInfo( @@ -48,6 +50,18 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { } }, [notify, onSave, t]) + const handleDeleteAvatar = useCallback(async () => { + try { + await updateUserProfile({ url: 'account/avatar', body: { avatar: '' } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + setIsShowDeleteConfirm(false) + onSave?.() + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + } + }, [notify, onSave, t]) + const { handleLocalFileUpload } = useLocalFileUploader({ limit: 3, disabled: false, @@ -86,12 +100,21 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" + onClick={() => hoverArea === 'right' ? setIsShowDeleteConfirm(true) : setIsShowAvatarPicker(true)} + onMouseMove={(e) => { + const rect = e.currentTarget.getBoundingClientRect() + const x = e.clientX - rect.left + const isRight = x > rect.width / 2 + setHoverArea(isRight ? 'right' : 'left') + }} > - + {hoverArea === 'right' ? + + : - + } +
@@ -115,6 +138,26 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { + + setIsShowDeleteConfirm(false)} + > +
{t('common.avatar.deleteTitle')}
+

{t('common.avatar.deleteDescription')}

+ +
+ + + +
+
) } diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index eb0f524386..a7bdc550d1 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -13,7 +13,7 @@ import Tooltip from '@/app/components/base/tooltip' import { AppType } from '@/types/app' import { getNewVar, getVars } from '@/utils/var' import AutomaticBtn from '@/app/components/app/configuration/config/automatic/automatic-btn' -import type { AutomaticRes } from '@/service/debug' +import type { GenRes } from '@/service/debug' import GetAutomaticResModal from '@/app/components/app/configuration/config/automatic/get-automatic-res' import PromptEditor from '@/app/components/base/prompt-editor' import ConfigContext from '@/context/debug-configuration' @@ -61,6 +61,7 @@ const Prompt: FC = ({ const { eventEmitter } = useEventEmitterContextContext() const { + appId, modelConfig, dataSets, setModelConfig, @@ -139,21 +140,21 @@ const Prompt: FC = ({ } const [showAutomatic, { setTrue: showAutomaticTrue, setFalse: showAutomaticFalse }] = useBoolean(false) - const handleAutomaticRes = (res: AutomaticRes) => { + const handleAutomaticRes = (res: GenRes) => { // put eventEmitter in first place to prevent overwrite the configs.prompt_variables.But another problem is that prompt won't hight the prompt_variables. eventEmitter?.emit({ type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, - payload: res.prompt, + payload: res.modified, } as any) const newModelConfig = produce(modelConfig, (draft) => { - draft.configs.prompt_template = res.prompt - draft.configs.prompt_variables = res.variables.map(key => ({ key, name: key, type: 'string', required: true })) + draft.configs.prompt_template = res.modified + draft.configs.prompt_variables = (res.variables || []).map(key => ({ key, name: key, type: 'string', required: true })) }) setModelConfig(newModelConfig) setPrevPromptConfig(modelConfig.configs) if (mode !== AppType.completion) { - setIntroduction(res.opening_statement) + setIntroduction(res.opening_statement || '') const newFeatures = produce(features, (draft) => { draft.opening = { ...draft.opening, @@ -272,10 +273,13 @@ const Prompt: FC = ({ {showAutomatic && ( )} diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index aacaa81ac2..31f81d274d 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useBoolean } from 'ahooks' +import { useBoolean, useSessionStorageState } from 'ahooks' import { RiDatabase2Line, RiFileExcel2Line, @@ -14,24 +14,18 @@ import { RiTranslate, RiUser2Line, } from '@remixicon/react' -import cn from 'classnames' import s from './style.module.css' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' -import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' -import { generateRule } from '@/service/debug' -import ConfigPrompt from '@/app/components/app/configuration/config-prompt' +import { generateBasicAppFistTimeRule, generateRule } from '@/service/debug' import type { CompletionParams, Model } from '@/types/app' -import { AppType } from '@/types/app' -import ConfigVar from '@/app/components/app/configuration/config-var' -import GroupName from '@/app/components/app/configuration/base/group-name' +import type { AppType } from '@/types/app' import Loading from '@/app/components/base/loading' import Confirm from '@/app/components/base/confirm' -import { LoveMessage } from '@/app/components/base/icons/src/vender/features' // type -import type { AutomaticRes } from '@/service/debug' +import type { GenRes } from '@/service/debug' import { Generator } from '@/app/components/base/icons/src/vender/other' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' @@ -39,13 +33,25 @@ import { ModelTypeEnum } from '@/app/components/header/account-setting/model-pro import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelModeType } from '@/types/app' import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import InstructionEditorInWorkflow from './instruction-editor-in-workflow' +import InstructionEditorInBasic from './instruction-editor' +import { GeneratorType } from './types' +import Result from './result' +import useGenData from './use-gen-data' +import IdeaOutput from './idea-output' +import ResPlaceholder from './res-placeholder' +import { useGenerateRuleTemplate } from '@/service/use-apps' +const i18nPrefix = 'appDebug.generate' export type IGetAutomaticResProps = { mode: AppType isShow: boolean onClose: () => void - onFinished: (res: AutomaticRes) => void - isInLLMNode?: boolean + onFinished: (res: GenRes) => void + flowId?: string + nodeId?: string + currentPrompt?: string + isBasicMode?: boolean } const TryLabel: FC<{ @@ -68,7 +74,10 @@ const GetAutomaticRes: FC = ({ mode, isShow, onClose, - isInLLMNode, + flowId, + nodeId, + currentPrompt, + isBasicMode, onFinished, }) => { const { t } = useTranslation() @@ -123,13 +132,27 @@ const GetAutomaticRes: FC = ({ }, ] - const [instruction, setInstruction] = useState('') + const [instructionFromSessionStorage, setInstruction] = useSessionStorageState(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}`}`) + const instruction = instructionFromSessionStorage || '' + const [ideaOutput, setIdeaOutput] = useState('') + + const [editorKey, setEditorKey] = useState(`${flowId}-0`) const handleChooseTemplate = useCallback((key: string) => { return () => { const template = t(`appDebug.generate.template.${key}.instruction`) setInstruction(template) + setEditorKey(`${flowId}-${Date.now()}`) } }, [t]) + + const { data: instructionTemplate } = useGenerateRuleTemplate(GeneratorType.prompt, isBasicMode) + useEffect(() => { + if (!instruction && instructionTemplate) + setInstruction(instructionTemplate.data) + + setEditorKey(`${flowId}-${Date.now()}`) + }, [instructionTemplate]) + const isValid = () => { if (instruction.trim() === '') { Toast.notify({ @@ -143,7 +166,10 @@ const GetAutomaticRes: FC = ({ return true } const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false) - const [res, setRes] = useState(null) + const storageKey = `${flowId}${isBasicMode ? '' : `-${nodeId}`}` + const { addVersion, current, currentVersionIndex, setCurrentVersionIndex, versions } = useGenData({ + storageKey, + }) useEffect(() => { if (defaultModel) { @@ -170,16 +196,6 @@ const GetAutomaticRes: FC = ({ ) - const renderNoData = ( -
- -
-
{t('appDebug.generate.noDataLine1')}
-
{t('appDebug.generate.noDataLine2')}
-
-
- ) - const handleModelChange = useCallback((newValue: { modelId: string; provider: string; mode?: string; features?: string[] }) => { const newModel = { ...model, @@ -207,28 +223,59 @@ const GetAutomaticRes: FC = ({ return setLoadingTrue() try { - const { error, ...res } = await generateRule({ - instruction, - model_config: model, - no_variable: !!isInLLMNode, - }) - setRes(res) - if (error) { - Toast.notify({ - type: 'error', - message: error, + let apiRes: GenRes + let hasError = false + if (isBasicMode || !currentPrompt) { + const { error, ...res } = await generateBasicAppFistTimeRule({ + instruction, + model_config: model, + no_variable: false, }) + apiRes = { + ...res, + modified: res.prompt, + } as GenRes + if (error) { + hasError = true + Toast.notify({ + type: 'error', + message: error, + }) + } } + else { + const { error, ...res } = await generateRule({ + flow_id: flowId, + node_id: nodeId, + current: currentPrompt, + instruction, + ideal_output: ideaOutput, + model_config: model, + }) + apiRes = res + if (error) { + hasError = true + Toast.notify({ + type: 'error', + message: error, + }) + } + } + if (!hasError) + addVersion(apiRes) } finally { setLoadingFalse() } } - const [showConfirmOverwrite, setShowConfirmOverwrite] = React.useState(false) + const [isShowConfirmOverwrite, { + setTrue: showConfirmOverwrite, + setFalse: hideShowConfirmOverwrite, + }] = useBoolean(false) const isShowAutoPromptResPlaceholder = () => { - return !isLoading && !res + return !isLoading && !current } return ( @@ -236,15 +283,14 @@ const GetAutomaticRes: FC = ({ isShow={isShow} onClose={onClose} className='min-w-[1140px] !p-0' - closable >
-
+
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
-
+
= ({ hideDebugWithMultipleModel />
-
-
-
{t('appDebug.generate.tryIt')}
-
-
-
- {tryList.map(item => ( - - ))} -
-
- {/* inputs */} -
-
-
{t('appDebug.generate.instruction')}
-