diff --git a/api/commands.py b/api/commands.py index 1858cb2734..d46750316b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -477,12 +477,12 @@ def convert_to_agent_apps(): click.echo(f"Converting app: {app.id}") try: - app.mode = AppMode.AGENT_CHAT.value + app.mode = AppMode.AGENT_CHAT db.session.commit() # update conversation mode to agent db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT.value} + {Conversation.mode: AppMode.AGENT_CHAT} ) db.session.commit() diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 9fd9b60194..9700447a4c 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -1,4 +1,4 @@ -import enum +from enum import Enum from typing import Literal, Optional from pydantic import Field, PositiveInt @@ -10,7 +10,7 @@ class OpenSearchConfig(BaseSettings): Configuration settings for OpenSearch """ - class AuthMethod(enum.StrEnum): + class AuthMethod(Enum): """ Authentication method for OpenSearch """ diff --git a/api/constants/model_template.py b/api/constants/model_template.py index c26d8c0186..cacf6b6874 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -7,7 +7,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { - "mode": AppMode.WORKFLOW.value, + "mode": AppMode.WORKFLOW, "enable_site": True, "enable_api": True, } @@ -15,7 +15,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # completion default mode AppMode.COMPLETION: { "app": { - "mode": AppMode.COMPLETION.value, + "mode": AppMode.COMPLETION, "enable_site": True, "enable_api": True, }, @@ -44,7 +44,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # chat default mode AppMode.CHAT: { "app": { - "mode": AppMode.CHAT.value, + "mode": AppMode.CHAT, "enable_site": True, "enable_api": True, }, @@ -60,7 +60,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # advanced-chat default mode AppMode.ADVANCED_CHAT: { "app": { - "mode": AppMode.ADVANCED_CHAT.value, + "mode": AppMode.ADVANCED_CHAT, "enable_site": True, "enable_api": True, }, @@ -68,7 +68,7 @@ default_app_templates: Mapping[AppMode, Mapping] = { # agent-chat default mode AppMode.AGENT_CHAT: { "app": { - "mode": AppMode.AGENT_CHAT.value, + "mode": AppMode.AGENT_CHAT, "enable_site": True, "enable_api": True, }, diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5ff444bb45..c0cbf6613e 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -307,7 +307,7 @@ class ChatConversationApi(Resource): .having(func.count(Message.id) >= args["message_count_gte"]) ) - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) match args["sort_by"]: diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index b4f7605136..11df511840 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -74,7 +74,7 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config original_app_model_config = ( db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index d9afb5bab2..7742ea24a9 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -20,7 +20,7 @@ class AppParameterApi(InstalledAppResource): if app_model is None: raise AppUnavailableError() - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 43b59d5334..80e6037cc7 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -150,7 +150,7 @@ class MCPAppApi(Resource): def _get_user_input_form(self, app: App) -> list[VariableEntity]: """Get and convert user input form""" # Get raw user input form based on app mode - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if not app.workflow: raise MCPRequestError(mcp_types.INVALID_REQUEST, "App is unavailable") raw_user_input_form = app.workflow.user_input_form(to_old_structure=True) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2dbeed1d68..25d7ccccec 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -29,7 +29,7 @@ class AppParameterApi(Resource): Returns the input form parameters and configuration for the application. """ - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index e0c3e997ce..2bc068ec75 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -38,7 +38,7 @@ class AppParameterApi(WebApiResource): @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index a3438fc2c7..1133ecf66c 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,4 +1,4 @@ -import enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator @@ -26,25 +26,25 @@ class AgentStrategyProviderIdentity(ToolProviderIdentity): class AgentStrategyParameter(PluginParameter): - class AgentStrategyParameterType(enum.StrEnum): + class AgentStrategyParameterType(StrEnum): """ Keep all the types from PluginParameterType """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -72,7 +72,7 @@ class AgentStrategyIdentity(ToolIdentity): pass -class AgentFeature(enum.StrEnum): +class AgentFeature(StrEnum): """ Agent Feature, used to describe the features of the agent strategy. """ diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index cda17c0010..ec4f6074ab 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -70,7 +70,7 @@ class PromptTemplateConfigManager: :param config: app model config args """ if not config.get("prompt_type"): - config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value + config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] if config["prompt_type"] not in prompt_type_vals: @@ -90,7 +90,7 @@ class PromptTemplateConfigManager: if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED: if not config["chat_prompt_config"] and not config["completion_prompt_config"]: raise ValueError( "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index df2074df2c..37745506e9 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -61,14 +61,14 @@ class PromptTemplateEntity(BaseModel): Prompt Template Entity. """ - class PromptType(Enum): + class PromptType(StrEnum): """ Prompt Type. 'simple', 'advanced' """ - SIMPLE = "simple" - ADVANCED = "advanced" + SIMPLE = auto() + ADVANCED = auto() @classmethod def value_of(cls, value: str): @@ -195,14 +195,14 @@ class DatasetRetrieveConfigEntity(BaseModel): Dataset Retrieve Config Entity. """ - class RetrieveStrategy(Enum): + class RetrieveStrategy(StrEnum): """ Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = "single" - MULTIPLE = "multiple" + SINGLE = auto() + MULTIPLE = auto() @classmethod def value_of(cls, value: str): @@ -293,12 +293,12 @@ class AppConfig(BaseModel): sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None -class EasyUIBasedAppModelConfigFrom(Enum): +class EasyUIBasedAppModelConfigFrom(StrEnum): """ App Model Config From. """ - ARGS = "args" + ARGS = auto() APP_LATEST_CONFIG = "app-latest-config" CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index fc04e60836..8d20b6bc1b 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel @@ -626,15 +626,15 @@ class QueueStopEvent(AppQueueEvent): QueueStopEvent entity """ - class StopBy(Enum): + class StopBy(StrEnum): """ Stop by enum """ - USER_MANUAL = "user-manual" - ANNOTATION_REPLY = "annotation-reply" - OUTPUT_MODERATION = "output-moderation" - INPUT_MODERATION = "input-moderation" + USER_MANUAL = auto() + ANNOTATION_REPLY = auto() + OUTPUT_MODERATION = auto() + INPUT_MODERATION = auto() event: QueueEvent = QueueEvent.STOP stopped_by: StopBy diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 31183d19a3..717e5d715a 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field @@ -50,37 +50,37 @@ class WorkflowTaskState(TaskState): answer: str = "" -class StreamEvent(Enum): +class StreamEvent(StrEnum): """ Stream event """ - PING = "ping" - ERROR = "error" - MESSAGE = "message" - MESSAGE_END = "message_end" - TTS_MESSAGE = "tts_message" - TTS_MESSAGE_END = "tts_message_end" - MESSAGE_FILE = "message_file" - MESSAGE_REPLACE = "message_replace" - AGENT_THOUGHT = "agent_thought" - AGENT_MESSAGE = "agent_message" - WORKFLOW_STARTED = "workflow_started" - WORKFLOW_FINISHED = "workflow_finished" - NODE_STARTED = "node_started" - NODE_FINISHED = "node_finished" - NODE_RETRY = "node_retry" - PARALLEL_BRANCH_STARTED = "parallel_branch_started" - PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" - ITERATION_STARTED = "iteration_started" - ITERATION_NEXT = "iteration_next" - ITERATION_COMPLETED = "iteration_completed" - LOOP_STARTED = "loop_started" - LOOP_NEXT = "loop_next" - LOOP_COMPLETED = "loop_completed" - TEXT_CHUNK = "text_chunk" - TEXT_REPLACE = "text_replace" - AGENT_LOG = "agent_log" + PING = auto() + ERROR = auto() + MESSAGE = auto() + MESSAGE_END = auto() + TTS_MESSAGE = auto() + TTS_MESSAGE_END = auto() + MESSAGE_FILE = auto() + MESSAGE_REPLACE = auto() + AGENT_THOUGHT = auto() + AGENT_MESSAGE = auto() + WORKFLOW_STARTED = auto() + WORKFLOW_FINISHED = auto() + NODE_STARTED = auto() + NODE_FINISHED = auto() + NODE_RETRY = auto() + PARALLEL_BRANCH_STARTED = auto() + PARALLEL_BRANCH_FINISHED = auto() + ITERATION_STARTED = auto() + ITERATION_NEXT = auto() + ITERATION_COMPLETED = auto() + LOOP_STARTED = auto() + LOOP_NEXT = auto() + LOOP_COMPLETED = auto() + TEXT_CHUNK = auto() + TEXT_REPLACE = auto() + AGENT_LOG = auto() class StreamResponse(BaseModel): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 71fd5ac653..d4d5c9f7d2 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -145,7 +145,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._task_state.metadata: extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] - if self._conversation_mode == AppMode.COMPLETION.value: + if self._conversation_mode == AppMode.COMPLETION: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index e865ba9d60..2cefb61df3 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -92,7 +92,7 @@ class MessageCycleManager: if not conversation: return - if conversation.mode != AppMode.COMPLETION.value: + if conversation.mode != AppMode.COMPLETION: app_model = conversation.app if not app_model: return diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 656bf4aa72..cf958b91d2 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -1,8 +1,8 @@ -from enum import Enum +from enum import StrEnum, auto -class PlanningStrategy(Enum): - ROUTER = "router" - REACT_ROUTER = "react_router" - REACT = "react" - FUNCTION_CALL = "function_call" +class PlanningStrategy(StrEnum): + ROUTER = auto() + REACT_ROUTER = auto() + REACT = auto() + FUNCTION_CALL = auto() diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 9b4934646b..89b48fd2ef 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,10 @@ -from enum import Enum +from enum import StrEnum, auto -class EmbeddingInputType(Enum): +class EmbeddingInputType(StrEnum): """ Enum for embedding input type. """ - DOCUMENT = "document" - QUERY = "query" + DOCUMENT = auto() + QUERY = auto() diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 0fd49b059c..4794a691bd 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, ConfigDict @@ -9,16 +9,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ProviderEntity -class ModelStatus(Enum): +class ModelStatus(StrEnum): """ Enum class for model status. """ - ACTIVE = "active" + ACTIVE = auto() NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" NO_PERMISSION = "no-permission" - DISABLED = "disabled" + DISABLED = auto() CREDENTIAL_REMOVED = "credential-removed" diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index fbd62437e6..0afb51edce 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -1,20 +1,20 @@ -from enum import StrEnum +from enum import StrEnum, auto class CommonParameterType(StrEnum): SECRET_INPUT = "secret-input" TEXT_INPUT = "text-input" - SELECT = "select" - STRING = "string" - NUMBER = "number" - FILE = "file" - FILES = "files" + SELECT = auto() + STRING = auto() + NUMBER = auto() + FILE = auto() + FILES = auto() SYSTEM_FILES = "system-files" - BOOLEAN = "boolean" + BOOLEAN = auto() APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" - ANY = "any" + ANY = auto() # Dynamic select parameter # Once you are not sure about the available options until authorization is done @@ -23,29 +23,29 @@ class CommonParameterType(StrEnum): # TOOL_SELECTOR = "tool-selector" # MCP object and array type parameters - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class AppSelectorScope(StrEnum): - ALL = "all" - CHAT = "chat" - WORKFLOW = "workflow" - COMPLETION = "completion" + ALL = auto() + CHAT = auto() + WORKFLOW = auto() + COMPLETION = auto() class ModelSelectorScope(StrEnum): - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - TTS = "tts" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - VISION = "vision" + RERANK = auto() + TTS = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + VISION = auto() class ToolSelectorScope(StrEnum): - ALL = "all" - CUSTOM = "custom" - BUILTIN = "builtin" - WORKFLOW = "workflow" + ALL = auto() + CUSTOM = auto() + BUILTIN = auto() + WORKFLOW = auto() diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 52acbc1eef..ad23f4381d 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum, auto from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field @@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod @@ -31,20 +31,20 @@ class ProviderQuotaType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class QuotaUnit(Enum): - TIMES = "times" - TOKENS = "tokens" - CREDITS = "credits" +class QuotaUnit(StrEnum): + TIMES = auto() + TOKENS = auto() + CREDITS = auto() -class SystemConfigurationStatus(Enum): +class SystemConfigurationStatus(StrEnum): """ Enum class for system configuration status. """ - ACTIVE = "active" + ACTIVE = auto() QUOTA_EXCEEDED = "quota-exceeded" - UNSUPPORTED = "unsupported" + UNSUPPORTED = auto() class RestrictModel(BaseModel): @@ -168,14 +168,14 @@ class BasicProviderConfig(BaseModel): Base model class for common provider settings like credentials """ - class Type(Enum): - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - TEXT_INPUT = CommonParameterType.TEXT_INPUT.value - SELECT = CommonParameterType.SELECT.value - BOOLEAN = CommonParameterType.BOOLEAN.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + class Type(StrEnum): + SECRET_INPUT = CommonParameterType.SECRET_INPUT + TEXT_INPUT = CommonParameterType.TEXT_INPUT + SELECT = CommonParameterType.SELECT + BOOLEAN = CommonParameterType.BOOLEAN + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR @classmethod def value_of(cls, value: str) -> "ProviderConfig.Type": diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index eee914a529..8cb9b4ac58 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -1,8 +1,8 @@ -import enum import importlib.util import json import logging import os +from enum import StrEnum, auto from pathlib import Path from typing import Any, Optional @@ -13,9 +13,9 @@ from core.helper.position_helper import sort_to_dict_by_position_map logger = logging.getLogger(__name__) -class ExtensionModule(enum.Enum): - MODERATION = "moderation" - EXTERNAL_DATA_TOOL = "external_data_tool" +class ExtensionModule(StrEnum): + MODERATION = auto() + EXTERNAL_DATA_TOOL = auto() class ModuleExtension(BaseModel): diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 1c112007cb..2f160628ca 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,12 +1,12 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError from typing import Optional from extensions.ext_redis import redis_client -class ProviderCredentialsCacheType(Enum): +class ProviderCredentialsCacheType(StrEnum): PROVIDER = "provider" MODEL = "provider_model" LOAD_BALANCING_MODEL = "load_balancing_provider_model" @@ -14,7 +14,7 @@ class ProviderCredentialsCacheType(Enum): class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" def get(self) -> Optional[dict]: """ diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 95a1086ca8..c2cc2e4db0 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,12 +1,12 @@ import json -from enum import Enum +from enum import StrEnum from json import JSONDecodeError from typing import Optional from extensions.ext_redis import redis_client -class ToolParameterCacheType(Enum): +class ToolParameterCacheType(StrEnum): PARAMETER = "tool_parameter" @@ -15,7 +15,7 @@ class ToolParameterCache: self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str ): self.cache_key = ( - f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f"{cache_type}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" f":identity_id:{identity_id}" ) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 6f52c65234..212c2eb073 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -142,7 +142,7 @@ def handle_call_tool( end_user, args, InvokeFrom.SERVICE_API, - streaming=app.mode == AppMode.AGENT_CHAT.value, + streaming=app.mode == AppMode.AGENT_CHAT, ) answer = extract_answer_from_response(app, response) @@ -157,7 +157,7 @@ def build_parameter_schema( """Build parameter schema for the tool""" parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) - if app_mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}: + if app_mode in {AppMode.COMPLETION, AppMode.WORKFLOW}: return { "type": "object", "properties": parameters, @@ -175,9 +175,9 @@ def build_parameter_schema( def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW.value: + if app.mode == AppMode.WORKFLOW: return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION.value: + elif app.mode == AppMode.COMPLETION: return {"query": "", "inputs": arguments} else: # Chat modes - create a copy to avoid modifying original dict @@ -218,13 +218,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" if app.mode in { - AppMode.ADVANCED_CHAT.value, - AppMode.COMPLETION.value, - AppMode.CHAT.value, - AppMode.AGENT_CHAT.value, + AppMode.ADVANCED_CHAT, + AppMode.COMPLETION, + AppMode.CHAT, + AppMode.AGENT_CHAT, }: return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW.value: + elif app.mode == AppMode.WORKFLOW: return json.dumps(response["data"]["outputs"], ensure_ascii=False) else: raise ValueError("Invalid app mode: " + str(app.mode)) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 7cd2e6a3d1..75ffe7c32f 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,20 +1,20 @@ from abc import ABC from collections.abc import Mapping, Sequence -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Annotated, Any, Literal, Optional, Union from pydantic import BaseModel, Field, field_serializer, field_validator -class PromptMessageRole(Enum): +class PromptMessageRole(StrEnum): """ Enum class for prompt message. """ - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - TOOL = "tool" + SYSTEM = auto() + USER = auto() + ASSISTANT = auto() + TOOL = auto() @classmethod def value_of(cls, value: str) -> "PromptMessageRole": @@ -54,11 +54,11 @@ class PromptMessageContentType(StrEnum): Enum class for prompt message content type. """ - TEXT = "text" - IMAGE = "image" - AUDIO = "audio" - VIDEO = "video" - DOCUMENT = "document" + TEXT = auto() + IMAGE = auto() + AUDIO = auto() + VIDEO = auto() + DOCUMENT = auto() class PromptMessageContent(ABC, BaseModel): @@ -108,8 +108,8 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent): """ class DETAIL(StrEnum): - LOW = "low" - HIGH = "high" + LOW = auto() + HIGH = auto() type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 568149cc37..8259335c50 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,5 +1,5 @@ from decimal import Decimal -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import Any, Optional from pydantic import BaseModel, ConfigDict, model_validator @@ -7,17 +7,17 @@ from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject -class ModelType(Enum): +class ModelType(StrEnum): """ Enum class for model type. """ - LLM = "llm" + LLM = auto() TEXT_EMBEDDING = "text-embedding" - RERANK = "rerank" - SPEECH2TEXT = "speech2text" - MODERATION = "moderation" - TTS = "tts" + RERANK = auto() + SPEECH2TEXT = auto() + MODERATION = auto() + TTS = auto() @classmethod def value_of(cls, origin_model_type: str) -> "ModelType": @@ -26,17 +26,17 @@ class ModelType(Enum): :return: model type """ - if origin_model_type in {"text-generation", cls.LLM.value}: + if origin_model_type in {"text-generation", cls.LLM}: return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK.value}: + elif origin_model_type in {"reranking", cls.RERANK}: return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS.value}: + elif origin_model_type in {"tts", cls.TTS}: return cls.TTS - elif origin_model_type == cls.MODERATION.value: + elif origin_model_type == cls.MODERATION: return cls.MODERATION else: raise ValueError(f"invalid origin model type {origin_model_type}") @@ -63,7 +63,7 @@ class ModelType(Enum): raise ValueError(f"invalid model type {self}") -class FetchFrom(Enum): +class FetchFrom(StrEnum): """ Enum class for fetch from. """ @@ -72,7 +72,7 @@ class FetchFrom(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class ModelFeature(Enum): +class ModelFeature(StrEnum): """ Enum class for llm feature. """ @@ -80,11 +80,11 @@ class ModelFeature(Enum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() STRUCTURED_OUTPUT = "structured-output" @@ -93,14 +93,14 @@ class DefaultParameterName(StrEnum): Enum class for parameter template variable. """ - TEMPERATURE = "temperature" - TOP_P = "top_p" - TOP_K = "top_k" - PRESENCE_PENALTY = "presence_penalty" - FREQUENCY_PENALTY = "frequency_penalty" - MAX_TOKENS = "max_tokens" - RESPONSE_FORMAT = "response_format" - JSON_SCHEMA = "json_schema" + TEMPERATURE = auto() + TOP_P = auto() + TOP_K = auto() + PRESENCE_PENALTY = auto() + FREQUENCY_PENALTY = auto() + MAX_TOKENS = auto() + RESPONSE_FORMAT = auto() + JSON_SCHEMA = auto() @classmethod def value_of(cls, value: Any) -> "DefaultParameterName": @@ -116,34 +116,34 @@ class DefaultParameterName(StrEnum): raise ValueError(f"invalid parameter name {value}") -class ParameterType(Enum): +class ParameterType(StrEnum): """ Enum class for parameter type. """ - FLOAT = "float" - INT = "int" - STRING = "string" - BOOLEAN = "boolean" - TEXT = "text" + FLOAT = auto() + INT = auto() + STRING = auto() + BOOLEAN = auto() + TEXT = auto() -class ModelPropertyKey(Enum): +class ModelPropertyKey(StrEnum): """ Enum class for model property key. """ - MODE = "mode" - CONTEXT_SIZE = "context_size" - MAX_CHUNKS = "max_chunks" - FILE_UPLOAD_LIMIT = "file_upload_limit" - SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions" - MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk" - DEFAULT_VOICE = "default_voice" - VOICES = "voices" - WORD_LIMIT = "word_limit" - AUDIO_TYPE = "audio_type" - MAX_WORKERS = "max_workers" + MODE = auto() + CONTEXT_SIZE = auto() + MAX_CHUNKS = auto() + FILE_UPLOAD_LIMIT = auto() + SUPPORTED_FILE_EXTENSIONS = auto() + MAX_CHARACTERS_PER_CHUNK = auto() + DEFAULT_VOICE = auto() + VOICES = auto() + WORD_LIMIT = auto() + AUDIO_TYPE = auto() + MAX_WORKERS = auto() class ProviderModel(BaseModel): @@ -220,13 +220,13 @@ class ModelUsage(BaseModel): pass -class PriceType(Enum): +class PriceType(StrEnum): """ Enum class for price type. """ - INPUT = "input" - OUTPUT = "output" + INPUT = auto() + OUTPUT = auto() class PriceInfo(BaseModel): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index c9aa8d1474..451c2359b3 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import Enum, StrEnum, auto from typing import Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -17,16 +17,16 @@ class ConfigurateMethod(Enum): CUSTOMIZABLE_MODEL = "customizable-model" -class FormType(Enum): +class FormType(StrEnum): """ Enum class for form type. """ TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" - SELECT = "select" - RADIO = "radio" - SWITCH = "switch" + SELECT = auto() + RADIO = auto() + SWITCH = auto() class FormShowOnObject(BaseModel): diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index f7bba0eba1..3ce438955a 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -47,7 +47,7 @@ class TextEmbeddingModel(AIModel): model=model, credentials=credentials, texts=texts, - input_type=input_type.value, + input_type=input_type, ) except Exception as e: raise self._transform_invoke_error(e) diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 962e417671..f65339fbfc 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,7 +18,7 @@ from pydantic_core import Url from pydantic_extra_types.color import Color -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any): +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) @@ -100,7 +100,7 @@ def jsonable_encoder( exclude_none: bool = False, custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, -): +) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 752617b654..340483c894 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, Field @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule -class ModerationAction(Enum): - DIRECT_OUTPUT = "direct_output" - OVERRIDDEN = "overridden" +class ModerationAction(StrEnum): + DIRECT_OUTPUT = auto() + OVERRIDDEN = auto() class ModerationInputsResult(BaseModel): diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py index 5d70264320..c9427c776a 100644 --- a/api/core/ops/aliyun_trace/entities/semconv.py +++ b/api/core/ops/aliyun_trace/entities/semconv.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum # public GEN_AI_SESSION_ID = "gen_ai.session.id" @@ -53,7 +53,7 @@ TOOL_DESCRIPTION = "tool.description" TOOL_PARAMETERS = "tool.parameters" -class GenAISpanKind(Enum): +class GenAISpanKind(StrEnum): CHAIN = "CHAIN" RETRIEVER = "RETRIEVER" RERANKER = "RERANKER" diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 48f44da68e..86cc3839f2 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -27,7 +27,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): app = cls._get_app(app_id, tenant_id) """Retrieve app parameters.""" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app.workflow if workflow is None: raise ValueError("unexpected app type") @@ -70,7 +70,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.AGENT_CHAT.value, AppMode.CHAT.value}: + if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: if not query: raise ValueError("missing query") @@ -96,7 +96,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT.value: + if app.mode == AppMode.ADVANCED_CHAT: workflow = app.workflow if not workflow: raise ValueError("unexpected app type") @@ -114,7 +114,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.AGENT_CHAT.value: + elif app.mode == AppMode.AGENT_CHAT: return AgentChatAppGenerator().generate( app_model=app, user=user, @@ -127,7 +127,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): invoke_from=InvokeFrom.SERVICE_API, streaming=stream, ) - elif app.mode == AppMode.CHAT.value: + elif app.mode == AppMode.CHAT: return ChatAppGenerator().generate( app_model=app, user=user, diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 92427a7426..01bb011ce7 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,5 +1,5 @@ -import enum import json +from enum import StrEnum, auto from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -25,44 +25,44 @@ class PluginParameterOption(BaseModel): return value -class PluginParameterType(enum.StrEnum): +class PluginParameterType(StrEnum): """ all available parameter types """ - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value - ANY = CommonParameterType.ANY.value - DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value + STRING = CommonParameterType.STRING + NUMBER = CommonParameterType.NUMBER + BOOLEAN = CommonParameterType.BOOLEAN + SELECT = CommonParameterType.SELECT + SECRET_INPUT = CommonParameterType.SECRET_INPUT + FILE = CommonParameterType.FILE + FILES = CommonParameterType.FILES + APP_SELECTOR = CommonParameterType.APP_SELECTOR + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR + ANY = CommonParameterType.ANY + DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES # MCP object and array type parameters - ARRAY = CommonParameterType.ARRAY.value - OBJECT = CommonParameterType.OBJECT.value + ARRAY = CommonParameterType.ARRAY + OBJECT = CommonParameterType.OBJECT -class MCPServerParameterType(enum.StrEnum): +class MCPServerParameterType(StrEnum): """ MCP server got complex parameter types """ - ARRAY = "array" - OBJECT = "object" + ARRAY = auto() + OBJECT = auto() class PluginParameterAutoGenerate(BaseModel): - class Type(enum.StrEnum): - PROMPT_INSTRUCTION = "prompt_instruction" + class Type(StrEnum): + PROMPT_INSTRUCTION = auto() type: Type @@ -93,7 +93,7 @@ class PluginParameter(BaseModel): return v -def as_normal_type(typ: enum.StrEnum): +def as_normal_type(typ: StrEnum): if typ.value in { PluginParameterType.SECRET_INPUT, PluginParameterType.SELECT, @@ -102,7 +102,7 @@ def as_normal_type(typ: enum.StrEnum): return typ.value -def cast_parameter_value(typ: enum.StrEnum, value: Any, /): +def cast_parameter_value(typ: StrEnum, value: Any, /): try: match typ.value: case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT: @@ -190,7 +190,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") -def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: Any): +def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): """ init frontend parameter by rule """ diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index a6369636e2..261d97c2b6 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,7 +1,7 @@ import datetime -import enum import re from collections.abc import Mapping +from enum import StrEnum, auto from typing import Any, Optional from packaging.version import InvalidVersion, Version @@ -16,11 +16,11 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -class PluginInstallationSource(enum.StrEnum): - Github = "github" - Marketplace = "marketplace" - Package = "package" - Remote = "remote" +class PluginInstallationSource(StrEnum): + Github = auto() + Marketplace = auto() + Package = auto() + Remote = auto() class PluginResourceRequirements(BaseModel): @@ -58,10 +58,10 @@ class PluginResourceRequirements(BaseModel): permission: Optional[Permission] = Field(default=None) -class PluginCategory(enum.StrEnum): - Tool = "tool" - Model = "model" - Extension = "extension" +class PluginCategory(StrEnum): + Tool = auto() + Model = auto() + Extension = auto() AgentStrategy = "agent-strategy" @@ -206,10 +206,10 @@ class ToolProviderID(GenericProviderID): class PluginDependency(BaseModel): - class Type(enum.StrEnum): - Github = PluginInstallationSource.Github.value - Marketplace = PluginInstallationSource.Marketplace.value - Package = PluginInstallationSource.Package.value + class Type(StrEnum): + Github = PluginInstallationSource.Github + Marketplace = PluginInstallationSource.Marketplace + Package = PluginInstallationSource.Package class Github(BaseModel): repo: str diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d15cb7cbc1..fdab0af103 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,7 @@ -import enum import json import os from collections.abc import Mapping, Sequence +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity @@ -25,9 +25,9 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(enum.StrEnum): - COMPLETION = "completion" - CHAT = "chat" +class ModelMode(StrEnum): + COMPLETION = auto() + CHAT = auto() prompt_file_contents: dict[str, Any] = {} diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 9887e21b7c..8fc94be360 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -1,13 +1,13 @@ -from enum import Enum +from enum import StrEnum, auto -class Field(Enum): +class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" GROUP_KEY = "group_id" - VECTOR = "vector" + VECTOR = auto() # Sparse Vector aims to support full text search - SPARSE_VECTOR = "sparse_vector" + SPARSE_VECTOR = auto() TEXT_KEY = "text" PRIMARY_KEY = "id" DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b590a4dfe4..17aac25b87 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -1,7 +1,7 @@ import json import logging import uuid -from enum import Enum +from enum import StrEnum from typing import Any from clickhouse_connect import get_client @@ -27,7 +27,7 @@ class MyScaleConfig(BaseModel): fts_params: str -class SortOrder(Enum): +class SortOrder(StrEnum): ASC = "ASC" DESC = "DESC" diff --git a/api/core/rag/extractor/entity/datasource_type.py b/api/core/rag/extractor/entity/datasource_type.py index 19ad300d11..6568f60ea2 100644 --- a/api/core/rag/extractor/entity/datasource_type.py +++ b/api/core/rag/extractor/entity/datasource_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class DatasourceType(Enum): +class DatasourceType(StrEnum): FILE = "upload_file" NOTION = "notion_import" WEBSITE = "website_crawl" diff --git a/api/core/rag/index_processor/constant/built_in_field.py b/api/core/rag/index_processor/constant/built_in_field.py index c8ad53e3dd..1d9ca89ba7 100644 --- a/api/core/rag/index_processor/constant/built_in_field.py +++ b/api/core/rag/index_processor/constant/built_in_field.py @@ -1,15 +1,15 @@ -from enum import Enum, StrEnum +from enum import StrEnum, auto class BuiltInField(StrEnum): - document_name = "document_name" - uploader = "uploader" - upload_date = "upload_date" - last_update_date = "last_update_date" - source = "source" + document_name = auto() + uploader = auto() + upload_date = auto() + last_update_date = auto() + source = auto() -class MetadataDataSource(Enum): +class MetadataDataSource(StrEnum): upload_file = "file_upload" website_crawl = "website" notion_import = "notion" diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 66304b30a5..f934b0bf96 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,8 +1,7 @@ import base64 import contextlib -import enum from collections.abc import Mapping -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator @@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY -class ToolLabelEnum(Enum): - SEARCH = "search" - IMAGE = "image" - VIDEOS = "videos" - WEATHER = "weather" - FINANCE = "finance" - DESIGN = "design" - TRAVEL = "travel" - SOCIAL = "social" - NEWS = "news" - MEDICAL = "medical" - PRODUCTIVITY = "productivity" - EDUCATION = "education" - BUSINESS = "business" - ENTERTAINMENT = "entertainment" - UTILITIES = "utilities" - OTHER = "other" +class ToolLabelEnum(StrEnum): + SEARCH = auto() + IMAGE = auto() + VIDEOS = auto() + WEATHER = auto() + FINANCE = auto() + DESIGN = auto() + TRAVEL = auto() + SOCIAL = auto() + NEWS = auto() + MEDICAL = auto() + PRODUCTIVITY = auto() + EDUCATION = auto() + BUSINESS = auto() + ENTERTAINMENT = auto() + UTILITIES = auto() + OTHER = auto() -class ToolProviderType(enum.StrEnum): +class ToolProviderType(StrEnum): """ Enum class for tool provider """ - PLUGIN = "plugin" + PLUGIN = auto() BUILT_IN = "builtin" - WORKFLOW = "workflow" - API = "api" - APP = "app" + WORKFLOW = auto() + API = auto() + APP = auto() DATASET_RETRIEVAL = "dataset-retrieval" - MCP = "mcp" + MCP = auto() @classmethod def value_of(cls, value: str) -> "ToolProviderType": @@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum): raise ValueError(f"invalid mode value {value}") -class ApiProviderSchemaType(Enum): +class ApiProviderSchemaType(StrEnum): """ Enum class for api provider schema type. """ - OPENAPI = "openapi" - SWAGGER = "swagger" - OPENAI_PLUGIN = "openai_plugin" - OPENAI_ACTIONS = "openai_actions" + OPENAPI = auto() + SWAGGER = auto() + OPENAI_PLUGIN = auto() + OPENAI_ACTIONS = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderSchemaType": @@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum): raise ValueError(f"invalid mode value {value}") -class ApiProviderAuthType(Enum): +class ApiProviderAuthType(StrEnum): """ Enum class for api provider auth type. """ - NONE = "none" - API_KEY_HEADER = "api_key_header" - API_KEY_QUERY = "api_key_query" + NONE = auto() + API_KEY_HEADER = auto() + API_KEY_QUERY = auto() @classmethod def value_of(cls, value: str) -> "ApiProviderAuthType": @@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel): return value class LogMessage(BaseModel): - class LogStatus(Enum): - START = "start" - ERROR = "error" - SUCCESS = "success" + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() id: str label: str = Field(..., description="The label of the log") @@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel): retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") - class MessageType(Enum): - TEXT = "text" - IMAGE = "image" - LINK = "link" - BLOB = "blob" - JSON = "json" - IMAGE_LINK = "image_link" - BINARY_LINK = "binary_link" - VARIABLE = "variable" - FILE = "file" - LOG = "log" - BLOB_CHUNK = "blob_chunk" - RETRIEVER_RESOURCES = "retriever_resources" + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() type: MessageType = MessageType.TEXT """ @@ -250,29 +249,29 @@ class ToolParameter(PluginParameter): Overrides type """ - class ToolParameterType(enum.StrEnum): + class ToolParameterType(StrEnum): """ removes TOOLS_SELECTOR from PluginParameterType """ - STRING = PluginParameterType.STRING.value - NUMBER = PluginParameterType.NUMBER.value - BOOLEAN = PluginParameterType.BOOLEAN.value - SELECT = PluginParameterType.SELECT.value - SECRET_INPUT = PluginParameterType.SECRET_INPUT.value - FILE = PluginParameterType.FILE.value - FILES = PluginParameterType.FILES.value - APP_SELECTOR = PluginParameterType.APP_SELECTOR.value - MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value - ANY = PluginParameterType.ANY.value - DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value + STRING = PluginParameterType.STRING + NUMBER = PluginParameterType.NUMBER + BOOLEAN = PluginParameterType.BOOLEAN + SELECT = PluginParameterType.SELECT + SECRET_INPUT = PluginParameterType.SECRET_INPUT + FILE = PluginParameterType.FILE + FILES = PluginParameterType.FILES + APP_SELECTOR = PluginParameterType.APP_SELECTOR + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR + ANY = PluginParameterType.ANY + DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT # MCP object and array type parameters - ARRAY = MCPServerParameterType.ARRAY.value - OBJECT = MCPServerParameterType.OBJECT.value + ARRAY = MCPServerParameterType.ARRAY + OBJECT = MCPServerParameterType.OBJECT # deprecated, should not use. - SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES def as_normal_type(self): return as_normal_type(self) @@ -280,10 +279,10 @@ class ToolParameter(PluginParameter): def cast_value(self, value: Any): return cast_parameter_value(self, value) - class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + class ToolParameterForm(StrEnum): + SCHEMA = auto() # should be set while adding tool + FORM = auto() # should be set before invoking tool + LLM = auto() # will be set by LLM type: ToolParameterType = Field(..., description="The type of the parameter") human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") @@ -446,14 +445,14 @@ class ToolLabel(BaseModel): icon: str = Field(..., description="The icon of the tool") -class ToolInvokeFrom(Enum): +class ToolInvokeFrom(StrEnum): """ Enum class for tool invoke """ - WORKFLOW = "workflow" - AGENT = "agent" - PLUGIN = "plugin" + WORKFLOW = auto() + AGENT = auto() + PLUGIN = auto() class ToolSelector(BaseModel): @@ -478,9 +477,9 @@ class ToolSelector(BaseModel): return self.model_dump() -class CredentialType(enum.StrEnum): +class CredentialType(StrEnum): API_KEY = "api-key" - OAUTH2 = "oauth2" + OAUTH2 = auto() def get_name(self): if self == CredentialType.API_KEY: diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 54440df725..d010ce05b8 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from typing import Optional from pydantic import BaseModel, Field @@ -11,12 +11,12 @@ from libs.datetime_utils import naive_utc_now class RouteNodeState(BaseModel): - class Status(Enum): - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - PAUSED = "paused" - EXCEPTION = "exception" + class Status(StrEnum): + RUNNING = auto() + SUCCESS = auto() + FAILED = auto() + PAUSED = auto() + EXCEPTION = auto() id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 11b11068e7..ce6eb33ecc 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,4 +1,4 @@ -from enum import Enum, StrEnum +from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel @@ -25,9 +25,9 @@ class AgentNodeData(BaseNodeData): agent_parameters: dict[str, AgentInput] -class ParamsAutoGenerated(Enum): - CLOSE = 0 - OPEN = 1 +class ParamsAutoGenerated(IntEnum): + CLOSE = auto() + OPEN = auto() class AgentOldVersionModelFeatures(StrEnum): @@ -38,8 +38,8 @@ class AgentOldVersionModelFeatures(StrEnum): TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" - VISION = "vision" + VISION = auto() STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" + DOCUMENT = auto() + VIDEO = auto() + AUDIO = auto() diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index a05cc44c99..850ff14880 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum, auto from pydantic import BaseModel, Field @@ -19,9 +19,9 @@ class GenerateRouteChunk(BaseModel): Generate Route Chunk. """ - class ChunkType(Enum): - VAR = "var" - TEXT = "text" + class ChunkType(StrEnum): + VAR = auto() + TEXT = auto() type: ChunkType = Field(..., description="generate route chunk type") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index d357fea7dd..3a77541807 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -259,7 +259,7 @@ class KnowledgeRetrievalNode(BaseNode): ) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: + if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -291,7 +291,7 @@ class KnowledgeRetrievalNode(BaseNode): metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index ef6b12fd59..43fe771bcd 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -9,19 +9,19 @@ import json import logging from dataclasses import asdict, dataclass from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional logger = logging.getLogger(__name__) -class FileStatus(Enum): +class FileStatus(StrEnum): """File status enumeration""" - ACTIVE = "active" # Active status - ARCHIVED = "archived" # Archived - DELETED = "deleted" # Deleted (soft delete) - BACKUP = "backup" # Backup file + ACTIVE = auto() # Active status + ARCHIVED = auto() # Archived + DELETED = auto() # Deleted (soft delete) + BACKUP = auto() # Backup file @dataclass diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 243df92efe..2431b08d81 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -5,13 +5,13 @@ According to ClickZetta's permission model, different Volume types have differen """ import logging -from enum import Enum +from enum import StrEnum from typing import Optional logger = logging.getLogger(__name__) -class VolumePermission(Enum): +class VolumePermission(StrEnum): """Volume permission type enumeration""" READ = "SELECT" # Corresponds to ClickZetta's SELECT permission diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 9dde87d800..5258823a07 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -7,7 +7,7 @@ eliminates the need for repetitive language switching logic. """ from dataclasses import dataclass -from enum import Enum +from enum import StrEnum, auto from typing import Any, Optional, Protocol from flask import render_template @@ -17,30 +17,30 @@ from extensions.ext_mail import mail from services.feature_service import BrandingModel, FeatureService -class EmailType(Enum): +class EmailType(StrEnum): """Enumeration of supported email types.""" - RESET_PASSWORD = "reset_password" - RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = "reset_password_when_account_not_exist" - INVITE_MEMBER = "invite_member" - EMAIL_CODE_LOGIN = "email_code_login" - CHANGE_EMAIL_OLD = "change_email_old" - CHANGE_EMAIL_NEW = "change_email_new" - CHANGE_EMAIL_COMPLETED = "change_email_completed" - OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm" - OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify" - OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify" - ACCOUNT_DELETION_SUCCESS = "account_deletion_success" - ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification" - ENTERPRISE_CUSTOM = "enterprise_custom" - QUEUE_MONITOR_ALERT = "queue_monitor_alert" - DOCUMENT_CLEAN_NOTIFY = "document_clean_notify" - EMAIL_REGISTER = "email_register" - EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = "email_register_when_account_exist" - RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = "reset_password_when_account_not_exist_no_register" + RESET_PASSWORD = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto() + INVITE_MEMBER = auto() + EMAIL_CODE_LOGIN = auto() + CHANGE_EMAIL_OLD = auto() + CHANGE_EMAIL_NEW = auto() + CHANGE_EMAIL_COMPLETED = auto() + OWNER_TRANSFER_CONFIRM = auto() + OWNER_TRANSFER_OLD_NOTIFY = auto() + OWNER_TRANSFER_NEW_NOTIFY = auto() + ACCOUNT_DELETION_SUCCESS = auto() + ACCOUNT_DELETION_VERIFICATION = auto() + ENTERPRISE_CUSTOM = auto() + QUEUE_MONITOR_ALERT = auto() + DOCUMENT_CLEAN_NOTIFY = auto() + EMAIL_REGISTER = auto() + EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto() + RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto() -class EmailLanguage(Enum): +class EmailLanguage(StrEnum): """Supported email languages with fallback handling.""" EN_US = "en-US" diff --git a/api/libs/helper.py b/api/libs/helper.py index f3c46b4843..09dbfac6cb 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw): if isinstance(obj, dict) and "app" in obj: obj = obj["app"] - if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: return file_helpers.get_signed_file_url(obj.icon) return None diff --git a/api/models/dataset.py b/api/models/dataset.py index 13087bf995..7314945053 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -224,35 +224,35 @@ class Dataset(Base): doc_metadata.append( { "id": "built-in", - "name": BuiltInField.document_name.value, + "name": BuiltInField.document_name, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.uploader.value, + "name": BuiltInField.uploader, "type": "string", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.upload_date.value, + "name": BuiltInField.upload_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.last_update_date.value, + "name": BuiltInField.last_update_date, "type": "time", } ) doc_metadata.append( { "id": "built-in", - "name": BuiltInField.source.value, + "name": BuiltInField.source, "type": "string", } ) @@ -544,7 +544,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.source, "type": "string", - "value": MetadataDataSource[self.data_source_type].value, + "value": MetadataDataSource[self.data_source_type], } ) return built_in_fields diff --git a/api/models/model.py b/api/models/model.py index 62b6467e33..50c07268dd 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID @@ -62,9 +62,9 @@ class AppMode(StrEnum): raise ValueError(f"invalid mode value {value}") -class IconType(Enum): - IMAGE = "image" - EMOJI = "emoji" +class IconType(StrEnum): + IMAGE = auto() + EMOJI = auto() class App(Base): @@ -149,15 +149,15 @@ class App(Base): if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get( "strategy", "" ) in {"function_call", "react"}: - self.mode = AppMode.AGENT_CHAT.value + self.mode = AppMode.AGENT_CHAT db.session.commit() return True return False @property def mode_compatible_with_agent(self) -> str: - if self.mode == AppMode.CHAT.value and self.is_agent: - return AppMode.AGENT_CHAT.value + if self.mode == AppMode.CHAT and self.is_agent: + return AppMode.AGENT_CHAT return str(self.mode) @@ -713,7 +713,7 @@ class Conversation(Base): model_config = {} app_model_config: Optional[AppModelConfig] = None - if self.mode == AppMode.ADVANCED_CHAT.value: + if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) model_config = override_model_configs diff --git a/api/models/provider.py b/api/models/provider.py index 9a344ea56d..17094f6d6e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum, auto from functools import cached_property from typing import Optional @@ -12,9 +12,9 @@ from .engine import db from .types import StringUUID -class ProviderType(Enum): - CUSTOM = "custom" - SYSTEM = "system" +class ProviderType(StrEnum): + CUSTOM = auto() + SYSTEM = auto() @staticmethod def value_of(value: str) -> "ProviderType": @@ -24,14 +24,14 @@ class ProviderType(Enum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(Enum): - PAID = "paid" +class ProviderQuotaType(StrEnum): + PAID = auto() """hosted paid quota""" - FREE = "free" + FREE = auto() """third-party free quota""" - TRIAL = "trial" + TRIAL = auto() """hosted trial quota""" @staticmethod diff --git a/api/models/workflow.py b/api/models/workflow.py index 4686b38b01..78582dd3fb 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Mapping, Sequence from datetime import datetime -from enum import Enum, StrEnum +from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, Optional, Union, cast from uuid import uuid4 @@ -41,13 +41,13 @@ from .types import EnumText, StringUUID logger = logging.getLogger(__name__) -class WorkflowType(Enum): +class WorkflowType(StrEnum): """ Workflow Type Enum """ - WORKFLOW = "workflow" - CHAT = "chat" + WORKFLOW = auto() + CHAT = auto() @classmethod def value_of(cls, value: str) -> "WorkflowType": @@ -777,7 +777,7 @@ class WorkflowNodeExecutionModel(Base): return extras -class WorkflowAppLogCreatedFrom(Enum): +class WorkflowAppLogCreatedFrom(StrEnum): """ Workflow App Log Created From Enum """ diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 6f0ab2546a..f2ffa3b170 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -32,14 +32,14 @@ class AdvancedPromptTemplateService: def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt @@ -73,7 +73,7 @@ class AdvancedPromptTemplateService: def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT.value: + if app_mode == AppMode.CHAT: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt @@ -82,7 +82,7 @@ class AdvancedPromptTemplateService: return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) - elif app_mode == AppMode.COMPLETION.value: + elif app_mode == AppMode.COMPLETION: if model_mode == "completion": return cls.get_completion_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index e812fcc992..c1ef206c99 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -60,7 +60,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return rate_limit.generate( CompletionAppGenerator.convert_to_event_stream( CompletionAppGenerator().generate( @@ -69,7 +69,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: return rate_limit.generate( AgentChatAppGenerator.convert_to_event_stream( AgentChatAppGenerator().generate( @@ -78,7 +78,7 @@ class AppGenerateService: ), request_id, ) - elif app_model.mode == AppMode.CHAT.value: + elif app_model.mode == AppMode.CHAT: return rate_limit.generate( ChatAppGenerator.convert_to_event_stream( ChatAppGenerator().generate( @@ -87,7 +87,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.ADVANCED_CHAT.value: + elif app_model.mode == AppMode.ADVANCED_CHAT: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -103,7 +103,7 @@ class AppGenerateService: ), request_id=request_id, ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow_id = args.get("workflow_id") workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( @@ -155,14 +155,14 @@ class AppGenerateService: @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_iteration_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_iteration_generate( @@ -174,14 +174,14 @@ class AppGenerateService: @classmethod def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().single_loop_generate( app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator.convert_to_event_stream( WorkflowAppGenerator().single_loop_generate( diff --git a/api/services/app_service.py b/api/services/app_service.py index 9b200a570d..c553ec8c19 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -40,15 +40,15 @@ class AppService: filters = [App.tenant_id == tenant_id, App.is_universal == False] if args["mode"] == "workflow": - filters.append(App.mode == AppMode.WORKFLOW.value) + filters.append(App.mode == AppMode.WORKFLOW) elif args["mode"] == "completion": - filters.append(App.mode == AppMode.COMPLETION.value) + filters.append(App.mode == AppMode.COMPLETION) elif args["mode"] == "chat": - filters.append(App.mode == AppMode.CHAT.value) + filters.append(App.mode == AppMode.CHAT) elif args["mode"] == "advanced-chat": - filters.append(App.mode == AppMode.ADVANCED_CHAT.value) + filters.append(App.mode == AppMode.ADVANCED_CHAT) elif args["mode"] == "agent-chat": - filters.append(App.mode == AppMode.AGENT_CHAT.value) + filters.append(App.mode == AppMode.AGENT_CHAT) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) @@ -171,7 +171,7 @@ class AppService: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None # get original app model config - if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: + if app.mode == AppMode.AGENT_CHAT or app.is_agent: model_config = app.app_model_config if not model_config: return app diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 9b1999d813..a7cd5e9487 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -88,7 +88,7 @@ class AudioService: def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False): with app.app_context(): if voice is None: - if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: workflow = WorkflowService().get_draft_workflow(app_model=app_model) else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8d1db20b04..8c8b368d42 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1004,7 +1004,7 @@ class DocumentService: if dataset.built_in_field_enabled: if document.doc_metadata: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = name + doc_metadata[BuiltInField.document_name] = name document.doc_metadata = doc_metadata document.name = name diff --git a/api/services/message_service.py b/api/services/message_service.py index 13c8e948ca..7c1f74e488 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -229,7 +229,7 @@ class MessageService: model_manager = ModelManager() - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() if invoke_from == InvokeFrom.DEBUGGER: workflow = workflow_service.get_draft_workflow(app_model=app_model) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 05fa5a95bc..208ecdb79e 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -131,11 +131,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name.value, "type": "string"}, - {"name": BuiltInField.uploader.value, "type": "string"}, - {"name": BuiltInField.upload_date.value, "type": "time"}, - {"name": BuiltInField.last_update_date.value, "type": "time"}, - {"name": BuiltInField.source.value, "type": "string"}, + {"name": BuiltInField.document_name, "type": "string"}, + {"name": BuiltInField.uploader, "type": "string"}, + {"name": BuiltInField.upload_date, "type": "time"}, + {"name": BuiltInField.last_update_date, "type": "time"}, + {"name": BuiltInField.source, "type": "string"}, ] @staticmethod @@ -153,11 +153,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) dataset.built_in_field_enabled = True @@ -183,11 +183,11 @@ class MetadataService: doc_metadata = {} else: doc_metadata = copy.deepcopy(document.doc_metadata) - doc_metadata.pop(BuiltInField.document_name.value, None) - doc_metadata.pop(BuiltInField.uploader.value, None) - doc_metadata.pop(BuiltInField.upload_date.value, None) - doc_metadata.pop(BuiltInField.last_update_date.value, None) - doc_metadata.pop(BuiltInField.source.value, None) + doc_metadata.pop(BuiltInField.document_name, None) + doc_metadata.pop(BuiltInField.uploader, None) + doc_metadata.pop(BuiltInField.upload_date, None) + doc_metadata.pop(BuiltInField.last_update_date, None) + doc_metadata.pop(BuiltInField.source, None) document.doc_metadata = doc_metadata db.session.add(document) document_ids.append(document.id) @@ -211,11 +211,11 @@ class MetadataService: for metadata_value in operation.metadata_list: doc_metadata[metadata_value.name] = metadata_value.value if dataset.built_in_field_enabled: - doc_metadata[BuiltInField.document_name.value] = document.name - doc_metadata[BuiltInField.uploader.value] = document.uploader - doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() - doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() - doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value + doc_metadata[BuiltInField.document_name] = document.name + doc_metadata[BuiltInField.uploader] = document.uploader + doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp() + doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() + doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata db.session.add(document) db.session.commit() diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index bae2921a27..34f10ae407 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -256,7 +256,7 @@ class PluginMigration: return [] agent_app_model_config_ids = [ - app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value + app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 8a58289b22..376f96c03a 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -65,7 +65,7 @@ class WorkflowConverter: new_app = App() new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" - new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW new_app.icon_type = icon_type or app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background @@ -203,7 +203,7 @@ class WorkflowConverter: app_mode_enum = AppMode.value_of(app_model.mode) app_config: EasyUIBasedAppConfig if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT.value + app_model.mode = AppMode.AGENT_CHAT app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -279,7 +279,7 @@ class WorkflowConverter: "app_id": app_model.id, "tool_variable": tool_variable, "inputs": inputs, - "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "", }, } @@ -618,7 +618,7 @@ class WorkflowConverter: :param app_model: App instance :return: AppMode """ - if app_model.mode == AppMode.COMPLETION.value: + if app_model.mode == AppMode.COMPLETION: return AppMode.WORKFLOW else: return AppMode.ADVANCED_CHAT diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4e0ae15841..e680de3502 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -828,7 +828,7 @@ class WorkflowService: # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow @@ -844,11 +844,11 @@ class WorkflowService: return new_app def validate_features_structure(self, app_model: App, features: dict): - if app_model.mode == AppMode.ADVANCED_CHAT.value: + if app_model.mode == AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) - elif app_model.mode == AppMode.WORKFLOW.value: + elif app_model.mode == AppMode.WORKFLOW: return WorkflowAppConfigManager.config_validate( tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 9ed9008af9..3ec265d009 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -42,7 +42,7 @@ class TestAdvancedPromptTemplateService: # Test data for Baichuan model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -77,7 +77,7 @@ class TestAdvancedPromptTemplateService: # Test data for common model args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -116,7 +116,7 @@ class TestAdvancedPromptTemplateService: for model_name in test_cases: args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": model_name, "has_context": "true", @@ -144,7 +144,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -173,7 +173,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -202,7 +202,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -230,7 +230,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -257,7 +257,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -303,7 +303,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -442,7 +442,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -473,7 +473,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -502,7 +502,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") # Assert: Verify the expected outcomes assert result is not None @@ -530,7 +530,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "true") # Assert: Verify the expected outcomes assert result is not None @@ -557,7 +557,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") # Assert: Verify the expected outcomes assert result is not None @@ -603,7 +603,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Act: Execute the method under test - result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "unsupported_mode", "true") # Assert: Verify empty dict is returned assert result == {} @@ -621,7 +621,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -653,7 +653,7 @@ class TestAdvancedPromptTemplateService: fake = Faker() # Test all app modes - app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + app_modes = [AppMode.CHAT, AppMode.COMPLETION] model_modes = ["completion", "chat"] for app_mode in app_modes: @@ -686,10 +686,10 @@ class TestAdvancedPromptTemplateService: # Test edge cases edge_cases = [ {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, - {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "", "has_context": "true"}, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "", @@ -723,7 +723,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -757,7 +757,7 @@ class TestAdvancedPromptTemplateService: # Test with context args = { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", @@ -786,25 +786,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "gpt-3.5-turbo", "has_context": "true", @@ -843,25 +843,25 @@ class TestAdvancedPromptTemplateService: # Test different scenarios test_scenarios = [ { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.CHAT.value, + "app_mode": AppMode.CHAT, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "completion", "model_name": "baichuan-13b-chat", "has_context": "true", }, { - "app_mode": AppMode.COMPLETION.value, + "app_mode": AppMode.COMPLETION, "model_mode": "chat", "model_name": "baichuan-13b-chat", "has_context": "true", diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 4646531a4e..d0f7e945f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -255,7 +255,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Try to create metadata with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name metadata_args = MetadataArgs(type="string", name=built_in_field_name) # Act & Assert: Verify proper error handling @@ -375,7 +375,7 @@ class TestMetadataService: metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name - built_in_field_name = BuiltInField.document_name.value + built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) @@ -540,11 +540,11 @@ class TestMetadataService: field_names = [field["name"] for field in result] field_types = [field["type"] for field in result] - assert BuiltInField.document_name.value in field_names - assert BuiltInField.uploader.value in field_names - assert BuiltInField.upload_date.value in field_names - assert BuiltInField.last_update_date.value in field_names - assert BuiltInField.source.value in field_names + assert BuiltInField.document_name in field_names + assert BuiltInField.uploader in field_names + assert BuiltInField.upload_date in field_names + assert BuiltInField.last_update_date in field_names + assert BuiltInField.source in field_names # Verify field types assert "string" in field_types @@ -682,11 +682,11 @@ class TestMetadataService: # Set document metadata with built-in fields document.doc_metadata = { - BuiltInField.document_name.value: document.name, - BuiltInField.uploader.value: "test_uploader", - BuiltInField.upload_date.value: 1234567890.0, - BuiltInField.last_update_date.value: 1234567890.0, - BuiltInField.source.value: "test_source", + BuiltInField.document_name: document.name, + BuiltInField.uploader: "test_uploader", + BuiltInField.upload_date: 1234567890.0, + BuiltInField.last_update_date: 1234567890.0, + BuiltInField.source: "test_source", } db.session.add(document) db.session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index 018eb6d896..b61df18b90 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -96,7 +96,7 @@ class TestWorkflowService: app.tenant_id = fake.uuid4() app.name = fake.company() app.description = fake.text() - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW app.icon_type = "emoji" app.icon = "🤖" app.icon_background = "#FFEAD5" @@ -883,7 +883,7 @@ class TestWorkflowService: # Create chat mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create app model config (required for conversion) from models.model import AppModelConfig @@ -926,7 +926,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.ADVANCED_CHAT.value # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW + assert result.mode == AppMode.ADVANCED_CHAT # CHAT mode converts to ADVANCED_CHAT, not WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -945,7 +945,7 @@ class TestWorkflowService: # Create completion mode app app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION # Create app model config (required for conversion) from models.model import AppModelConfig @@ -988,7 +988,7 @@ class TestWorkflowService: # Assert assert result is not None - assert result.mode == AppMode.WORKFLOW.value + assert result.mode == AppMode.WORKFLOW assert result.name == conversion_args["name"] assert result.icon == conversion_args["icon"] assert result.icon_type == conversion_args["icon_type"] @@ -1007,7 +1007,7 @@ class TestWorkflowService: # Create workflow mode app (already in workflow mode) app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db @@ -1030,7 +1030,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.ADVANCED_CHAT.value + app.mode = AppMode.ADVANCED_CHAT from extensions.ext_database import db @@ -1061,7 +1061,7 @@ class TestWorkflowService: # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f1d741602a..895ebdd751 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -29,7 +29,7 @@ class TestHandleMCPRequest: """Setup test fixtures""" self.app = Mock(spec=App) self.app.name = "test_app" - self.app.mode = AppMode.CHAT.value + self.app.mode = AppMode.CHAT self.mcp_server = Mock(spec=AppMCPServer) self.mcp_server.description = "Test server" @@ -196,7 +196,7 @@ class TestIndividualHandlers: def test_handle_list_tools(self): """Test list tools handler""" app_name = "test_app" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT description = "Test server" parameters_dict: dict[str, str] = {} user_input_form: list[VariableEntity] = [] @@ -212,7 +212,7 @@ class TestIndividualHandlers: def test_handle_call_tool(self, mock_app_generate): """Test call tool handler""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT # Create mock request mock_request = Mock() @@ -252,7 +252,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_chat_mode(self): """Test building parameter schema for chat mode""" - app_mode = AppMode.CHAT.value + app_mode = AppMode.CHAT parameters_dict: dict[str, str] = {"name": "Enter your name"} user_input_form = [ @@ -275,7 +275,7 @@ class TestUtilityFunctions: def test_build_parameter_schema_workflow_mode(self): """Test building parameter schema for workflow mode""" - app_mode = AppMode.WORKFLOW.value + app_mode = AppMode.WORKFLOW parameters_dict: dict[str, str] = {"input_text": "Enter text"} user_input_form = [ @@ -298,7 +298,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_chat_mode(self): """Test preparing tool arguments for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT arguments = {"query": "test question", "name": "John"} @@ -312,7 +312,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_workflow_mode(self): """Test preparing tool arguments for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW arguments = {"input_text": "test input"} @@ -324,7 +324,7 @@ class TestUtilityFunctions: def test_prepare_tool_arguments_completion_mode(self): """Test preparing tool arguments for completion mode""" app = Mock(spec=App) - app.mode = AppMode.COMPLETION.value + app.mode = AppMode.COMPLETION arguments = {"name": "John"} @@ -336,7 +336,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_chat(self): """Test extracting answer from mapping response for chat mode""" app = Mock(spec=App) - app.mode = AppMode.CHAT.value + app.mode = AppMode.CHAT response = {"answer": "test answer", "other": "data"} @@ -347,7 +347,7 @@ class TestUtilityFunctions: def test_extract_answer_from_mapping_response_workflow(self): """Test extracting answer from mapping response for workflow mode""" app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW response = {"data": {"outputs": {"result": "test result"}}} diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 0a09167349..2ca781bae5 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -66,7 +66,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT.value + app_model.mode = AppMode.CHAT api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension( @@ -127,7 +127,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): app_model = MagicMock() app_model.id = "app_id" app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW.value + app_model.mode = AppMode.WORKFLOW api_based_extension_id = "api_based_extension_id" mock_api_based_extension = APIBasedExtension(