remove .value (#26633)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Asuka Minato 2025-10-11 10:08:29 +09:00 committed by GitHub
parent bb6a331490
commit 1bd621f819
138 changed files with 613 additions and 633 deletions

View File

@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
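
This first hunk shows the pattern behind the whole commit: `AuthMethod` moves from `Enum` to `StrEnum`, and string-based enum members are real `str` instances, so the `.value` suffix becomes redundant wherever a member is compared with, formatted as, or stored as its string. A minimal sketch of the difference, using stand-in enums rather than the real `AuthMethod`:

```python
from enum import Enum, StrEnum  # StrEnum requires Python 3.11+

class PlainAuth(Enum):           # old style: members wrap a value
    BASIC = "basic"

class StrAuth(StrEnum):          # new style: members are strings
    BASIC = "basic"

print(PlainAuth.BASIC == "basic")      # False -- would need PlainAuth.BASIC.value
print(StrAuth.BASIC == "basic")        # True  -- .value is redundant
print(isinstance(StrAuth.BASIC, str))  # True
print(f"{StrAuth.BASIC}")              # basic -- str()/format() yield the raw value
```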

View File

@ -304,7 +304,7 @@ class AppCopyApi(Resource):
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),

View File

@ -70,9 +70,9 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -309,7 +309,7 @@ class ChatConversationApi(Resource):
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
case "created_at":

View File

@ -52,7 +52,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
@ -190,7 +190,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -263,7 +263,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -345,7 +345,7 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -432,7 +432,7 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -509,7 +509,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -584,7 +584,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
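
The same property is what lets the bare member go straight into `arg_dict` as a bind parameter: a `StrEnum` member passes `isinstance(x, str)` and its character data is the value, so the driver receives the plain string. A minimal sketch with sqlite3 and a stand-in enum (the real code binds through SQLAlchemy against the configured database, so behaviour there is worth confirming with that driver):

```python
import sqlite3
from enum import StrEnum

class InvokeFrom(StrEnum):   # stand-in for the real InvokeFrom enum
    DEBUGGER = "debugger"

conn = sqlite3.connect(":memory:")
# A str subclass binds like an ordinary string, so the database sees 'debugger'.
row = conn.execute("SELECT :invoke_from", {"invoke_from": InvokeFrom.DEBUGGER}).fetchone()
print(row)  # ('debugger',)
```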

View File

@ -47,7 +47,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -115,7 +115,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -183,7 +183,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -269,7 +269,7 @@ GROUP BY
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)

View File

@ -103,7 +103,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -130,11 +130,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -256,7 +256,7 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,

View File

@ -500,7 +500,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
@ -512,7 +512,7 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
@ -529,7 +529,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
@ -786,7 +786,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
@ -813,9 +813,9 @@ class DatasetRetrievalSettingApi(Resource):
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
@ -842,7 +842,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
@ -867,9 +867,9 @@ class DatasetRetrievalSettingMockApi(Resource):
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
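
The retrieval-method lists above end up in JSON responses, which also tolerates dropping `.value`: `json.dumps` encodes any `str` subclass by its character data, so a `StrEnum` member serialises as the bare string. A small check with a stand-in enum:

```python
import json
from enum import StrEnum

class RetrievalMethod(StrEnum):  # stand-in for the real RetrievalMethod enum
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"
    HYBRID_SEARCH = "hybrid_search"

payload = {"retrieval_method": list(RetrievalMethod)}
print(json.dumps(payload))
# {"retrieval_method": ["semantic_search", "full_text_search", "hybrid_search"]}
```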

View File

@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
@ -546,7 +546,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
@ -563,7 +563,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],

View File

@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (

View File

@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
)
session.add(end_user)

View File

@ -197,12 +197,12 @@ class DatasetConfigManager:
# strategy
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
has_datasets = False
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]

View File

@ -68,7 +68,7 @@ class ModelConfigConverter:
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value

View File

@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config["model"]["mode"] == ModelMode.CHAT.value:
if config["model"]["mode"] == ModelMode.CHAT:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:

View File

@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

View File

@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")

View File

@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -100,8 +100,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -244,8 +244,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
parameter["type"] = "files"
# -------------

View File

@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)
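
Here the members of one enum are defined by aliasing another enum's members (`STRING = PluginParameterType.STRING`). With `StrEnum` this preserves the underlying string, because `StrEnum.__new__` stores `str(value)` and `str()` of a string-based member is its raw value; it does rely on `PluginParameterType` itself being string-based, since `str()` of a plain `Enum` member is its qualified name. A hedged sketch with stand-in enums:

```python
from enum import StrEnum

class PluginParameterType(StrEnum):      # stand-in; assumed to be string-based
    STRING = "string"

class DatasourceParameterType(StrEnum):  # stand-in for the enum defined above
    STRING = PluginParameterType.STRING  # stored value becomes the plain string "string"

print(DatasourceParameterType.STRING.value)                           # string
print(DatasourceParameterType.STRING == "string")                     # True
print(DatasourceParameterType.STRING is PluginParameterType.STRING)   # False -- distinct members
```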

View File

@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_type == ProviderType.CUSTOM,
Provider.provider_name.in_(self._get_provider_names()),
)
@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
provider_type=ProviderType.CUSTOM,
is_valid=True,
credential_id=new_record.id,
)
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -343,7 +343,7 @@ class IndexingRunner:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
@ -356,7 +356,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
@ -379,7 +379,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],

View File

@ -224,8 +224,8 @@ def _handle_native_json_schema(
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
return model_parameters
@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
if ResponseFormat.JSON in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON
elif ResponseFormat.JSON_OBJECT in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
def _handle_prompt_based_schema(

View File

@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
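
In this hunk `.value` is kept, just moved to the single point of use (`span_kind.value`): `OpenInferenceSpanKindValues` is a third-party enum and, assuming it is a plain `Enum` rather than a string-based one, the member itself would format as its qualified name instead of the wire value the tracer expects. A sketch of that behaviour with a hypothetical plain enum:

```python
from enum import Enum

class SpanKind(Enum):        # hypothetical plain Enum standing in for the third-party type
    CHAIN = "CHAIN"
    LLM = "LLM"

print(f"{SpanKind.LLM}")     # SpanKind.LLM -- not the wire value
print(SpanKind.LLM == "LLM") # False
print(SpanKind.LLM.value)    # LLM -- .value is still required here
```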

View File

@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
input={
"message": trace_info.inputs,
"files": file_list,
@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
input=trace_info.inputs,
output={
"action": trace_info.action,
@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
generation_data = LangfuseGeneration(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
input=trace_info.inputs,
output=str(trace_info.suggested_question),
trace_id=trace_info.trace_id or trace_info.message_id,
@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
input=trace_info.inputs,
output={"documents": trace_info.documents},
trace_id=trace_info.trace_id or trace_info.message_id,
@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_generation_trace_data = LangfuseTrace(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
user_id=trace_info.tenant_id,
@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=name_generation_trace_data)
name_generation_span_data = LangfuseSpan(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
trace_id=trace_info.conversation_id,

View File

@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance):
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
id=message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=trace_info.inputs,
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data is None:
return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_run = LangSmithRunModel(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,

View File

@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance):
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE.value,
"name": TraceTaskName.WORKFLOW_TRACE,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance):
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE.value,
"name": TraceTaskName.MODERATION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),

View File

@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
total_tokens=trace_info.total_tokens,
@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE.value),
op=str(TraceTaskName.WORKFLOW_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance):
moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE.value),
op=str(TraceTaskName.MODERATION_TRACE),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance):
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance):
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance):
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
op=str(TraceTaskName.GENERATE_NAME_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,

View File

@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@ -83,13 +83,13 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
raise ValueError("prompt_messages must be a list")
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
if v[i]["role"] == PromptMessageRole.USER:
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
elif v[i]["role"] == PromptMessageRole.SYSTEM:
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
elif v[i]["role"] == PromptMessageRole.TOOL:
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage.model_validate(v[i])

View File

@ -610,7 +610,7 @@ class ProviderManager:
provider_quota_to_provider_record_dict = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -627,8 +627,8 @@ class ProviderManager:
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
provider_type=ProviderType.SYSTEM,
quota_type=ProviderQuotaType.TRIAL,
quota_limit=quota.quota_limit, # type: ignore
quota_used=0,
is_valid=True,
@ -641,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == ProviderQuotaType.TRIAL,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -702,7 +702,7 @@ class ProviderManager:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None
)
if not custom_provider_record:
@ -905,7 +905,7 @@ class ProviderManager:
# Convert provider_records to dict
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -1046,7 +1046,7 @@ class ProviderManager:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -46,7 +46,7 @@ class DataPostProcessor:
reranking_model: dict | None = None,
weights: dict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode,
tenant_id=tenant_id,
@ -62,7 +62,7 @@ class DataPostProcessor:
),
)
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
elif reranking_mode == RerankMode.RERANKING_MODEL:
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
if rerank_model_instance is None:
return None

View File

@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@ -107,7 +107,7 @@ class RetrievalService:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
@ -245,10 +245,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
@ -293,10 +293,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(

View File

@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector):
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
{Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector):
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
# Check if vector index already exists on the embedding column
if Field.VECTOR.value in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
if Field.VECTOR in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR})
PROPERTIES (
"distance.function" = "{self._config.vector_distance_function}",
"scalar.type" = "f32",
@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector):
# More precise check: look for inverted index specifically on the content column
if (
"inverted" in idx_str
and Field.CONTENT_KEY.value.lower() in idx_str
and Field.CONTENT_KEY.lower() in idx_str
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY})
PROPERTIES (
"analyzer" = "{self._config.analyzer_type}",
"mode" = "{self._config.analyzer_mode}"
@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector):
or "with the same type" in error_msg
or "cannot create inverted index" in error_msg
) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY)
# Try to get the existing index name for logging
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower():
logger.info("Found existing inverted index: %s", idx)
break
except (RuntimeError, ValueError):
@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector):
# Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector):
# Use json_extract_string function for ClickZetta compatibility
sql = (
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?"
)
cursor.execute(sql, binding_params=[value])
@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector):
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
)
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}")
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}")
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
# Execute vector search query
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY},
{distance_func}({Field.VECTOR}, {query_vector_str}) AS distance
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
ORDER BY distance
@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector):
# match_all requires all terms to be present
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')")
where_clause = " AND ".join(filter_clauses)
# Execute full-text search query
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
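
Throughout this file, `Field` members are interpolated directly into SQL f-strings and used with `str` methods (`Field.CONTENT_KEY.lower()` above). Both rely on the enum being string-based: `format()` on a `StrEnum` member yields the raw value, and inherited `str` methods return plain strings. A small sketch with a stand-in `Field` enum:

```python
from enum import StrEnum

class Field(StrEnum):             # stand-in for the vector-store Field enum
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"
    VECTOR = "vector"

columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
print(columns)                    # id, page_content, metadata, vector
print(Field.CONTENT_KEY.lower())  # page_content -- inherited str method, returns a plain str
```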

View File

@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector):
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
Field.CONTENT_KEY: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str = {
"bool": {
"must": {"match": {Field.CONTENT_KEY.value: query}},
"must": {"match": {Field.CONTENT_KEY: query}},
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
}
}
@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector):
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type

View File

@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector):
"size": top_k,
"query": {
"vector": {
Field.VECTOR.value: {
Field.VECTOR: {
"vector": query_vector,
"topk": top_k,
}
@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
query_str = {"match": {Field.CONTENT_KEY: query}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "vector",
"dimension": dim,
"indexing": True,
@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector):
"neighbors": 32,
"efc": 128,
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector):
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector):
search_query: dict[str, Any] = {
"size": top_k,
"_source": True,
"query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
"query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
}
final_ext: dict[str, Any] = {"lvector": {}}
if filters is not None and len(filters) > 0:
# when using filter, transform filter from List[Dict] to Dict as valid format
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict
search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict
final_ext["lvector"]["filter_type"] = "pre_filter"
if final_ext != {"lvector": {}}:
@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector):
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {


@ -85,7 +85,7 @@ class MilvusVector(BaseVector):
collection_info = self._client.describe_collection(self._collection_name)
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value]
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
def _check_hybrid_search_support(self) -> bool:
"""
@ -130,9 +130,9 @@ class MilvusVector(BaseVector):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@ -243,15 +243,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
anns_field=Field.VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -264,7 +264,7 @@ class MilvusVector(BaseVector):
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR.value):
if not self.field_exists(Field.SPARSE_VECTOR):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
@ -279,15 +279,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
anns_field=Field.SPARSE_VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -311,7 +311,7 @@ class MilvusVector(BaseVector):
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
@ -326,15 +326,15 @@ class MilvusVector(BaseVector):
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))
schema = CollectionSchema(fields)
@ -342,8 +342,8 @@ class MilvusVector(BaseVector):
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
input_field_names=[Field.CONTENT_KEY],
output_field_names=[Field.SPARSE_VECTOR],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
@ -352,12 +352,12 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection


@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Literal
from typing import Any
from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel):
port: int
secure: bool = False # use_ssl
verify_certs: bool = True
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
auth_method: AuthMethod = AuthMethod.BASIC
user: str | None = None
password: str | None = None
aws_region: str | None = None
@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector):
"_op_type": "index",
"_index": self._collection_name.lower(),
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector):
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector):
query = {
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
"query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {
"script_score": {
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
"script": {
"source": "knn_score",
"lang": "knn",
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
"params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
},
}
}
@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
metadata = hit["_source"].get(Field.METADATA_KEY, {})
# Make sure metadata is a dictionary
if metadata is None:
@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
return docs
@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector):
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector):
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
aws_region=dify_config.OPENSEARCH_AWS_REGION,


@ -147,15 +147,13 @@ class QdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -165,9 +163,7 @@ class QdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -220,10 +216,10 @@ class QdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -381,12 +377,12 @@ class QdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -433,7 +429,7 @@ class QdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents


@ -55,7 +55,7 @@ class TableStoreVector(BaseVector):
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
self._tags_field = f"{Field.METADATA_KEY}_tags"
def create_collection(self, embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
@ -64,7 +64,7 @@ class TableStoreVector(BaseVector):
def get_by_ids(self, ids: list[str]) -> list[Document]:
docs = []
request = BatchGetRowRequest()
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
rows_to_get = [[("id", _id)] for _id in ids]
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
@ -73,11 +73,7 @@ class TableStoreVector(BaseVector):
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(
Document(
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
)
)
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
return docs
def get_type(self) -> str:
@ -95,9 +91,9 @@ class TableStoreVector(BaseVector):
self._write_row(
primary_key=uuids[i],
attributes={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
},
)
return uuids
@ -180,7 +176,7 @@ class TableStoreVector(BaseVector):
field_schemas = [
tablestore.FieldSchema(
Field.CONTENT_KEY.value,
Field.CONTENT_KEY,
tablestore.FieldType.TEXT,
analyzer=tablestore.AnalyzerType.MAXWORD,
index=True,
@ -188,7 +184,7 @@ class TableStoreVector(BaseVector):
store=False,
),
tablestore.FieldSchema(
Field.VECTOR.value,
Field.VECTOR,
tablestore.FieldType.VECTOR,
vector_options=tablestore.VectorOptions(
data_type=tablestore.VectorDataType.VD_FLOAT_32,
@ -197,7 +193,7 @@ class TableStoreVector(BaseVector):
),
),
tablestore.FieldSchema(
Field.METADATA_KEY.value,
Field.METADATA_KEY,
tablestore.FieldType.KEYWORD,
index=True,
store=False,
@ -233,15 +229,15 @@ class TableStoreVector(BaseVector):
pk = [("id", primary_key)]
tags = []
for key, value in attributes[Field.METADATA_KEY.value].items():
for key, value in attributes[Field.METADATA_KEY].items():
tags.append(str(key) + "=" + str(value))
attribute_columns = [
(Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]),
(Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])),
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
(
Field.METADATA_KEY.value,
json.dumps(attributes[Field.METADATA_KEY.value]),
Field.METADATA_KEY,
json.dumps(attributes[Field.METADATA_KEY]),
),
(self._tags_field, json.dumps(tags)),
]
@ -270,7 +266,7 @@ class TableStoreVector(BaseVector):
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)
@ -288,7 +284,7 @@ class TableStoreVector(BaseVector):
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
field_name=Field.VECTOR,
top_k=top_k,
float32_query_vector=query_vector,
)
@ -311,8 +307,8 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector_str = ots_column_map.get(Field.VECTOR)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
@ -321,7 +317,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
@ -343,7 +339,7 @@ class TableStoreVector(BaseVector):
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
@ -374,10 +370,10 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR.value)
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None
if score:
@ -385,7 +381,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)


@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "",
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents


@ -55,13 +55,13 @@ class TiDBVector(BaseVector):
return Table(
self._collection_name,
self._orm_base.metadata,
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
Column(
Field.VECTOR.value,
Field.VECTOR,
VectorType(dim),
nullable=False,
),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column(Field.TEXT_KEY, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(


@ -76,11 +76,11 @@ class VikingDBVector(BaseVector):
if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
]
self._client.create_collection(
@ -100,7 +100,7 @@ class VikingDBVector(BaseVector):
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY.value,
partition_by=vdb_Field.GROUP_KEY,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -126,11 +126,11 @@ class VikingDBVector(BaseVector):
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
vdb_Field.GROUP_KEY.value: self._group_id,
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY: page_content,
vdb_Field.METADATA_KEY: json.dumps(metadata),
vdb_Field.GROUP_KEY: self._group_id,
}
)
docs.append(doc)
@ -151,7 +151,7 @@ class VikingDBVector(BaseVector):
# Note: Metadata field value is an dict, but vikingdb field
# not support json type
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)
@ -161,7 +161,7 @@ class VikingDBVector(BaseVector):
ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
@ -189,12 +189,12 @@ class VikingDBVector(BaseVector):
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs


@ -104,7 +104,7 @@ class WeaviateVector(BaseVector):
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
@ -182,7 +182,7 @@ class WeaviateVector(BaseVector):
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
@ -204,7 +204,7 @@ class WeaviateVector(BaseVector):
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
@ -232,7 +232,7 @@ class WeaviateVector(BaseVector):
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
@ -250,7 +250,7 @@ class WeaviateVector(BaseVector):
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs


@ -20,12 +20,12 @@ class BaseDatasourceEvent(BaseModel):
class DatasourceErrorEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.ERROR.value
event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR
error: str = Field(..., description="error message")
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
event: DatasourceStreamEvent = DatasourceStreamEvent.COMPLETED
data: Mapping[str, Any] | list = Field(..., description="result")
total: int | None = Field(default=0, description="total")
completed: int | None = Field(default=0, description="completed")
@ -33,6 +33,6 @@ class DatasourceCompletedEvent(BaseDatasourceEvent):
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
event: DatasourceStreamEvent = DatasourceStreamEvent.PROCESSING
total: int | None = Field(..., description="total")
completed: int | None = Field(..., description="completed")
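
Typing the event field with the enum itself, instead of str plus a .value default, is assumed to leave the wire format unchanged: under Pydantic v2 a StrEnum field still accepts the plain strings on input and still emits them when dumping to JSON. A small sketch under those assumptions, with stand-in event values:

from enum import StrEnum
from pydantic import BaseModel, Field

class StreamEvent(StrEnum):  # stand-in for the real DatasourceStreamEvent
    ERROR = "error"
    COMPLETED = "completed"
    PROCESSING = "processing"

class ErrorEvent(BaseModel):
    event: StreamEvent = StreamEvent.ERROR
    error: str = Field(..., description="error message")

evt = ErrorEvent(error="boom")
assert evt.event == "error"  # the member compares equal to its value
assert evt.model_dump(mode="json") == {"event": "error", "error": "boom"}

# Plain strings validate back into members.
assert ErrorEvent(event="completed", error="x").event is StreamEvent.COMPLETED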


@ -45,7 +45,7 @@ class ExtractProcessor:
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = "\n"
@ -76,7 +76,7 @@ class ExtractProcessor:
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
delimiter = "\n"
return delimiter.join(
@ -92,7 +92,7 @@ class ExtractProcessor:
def extract(
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
@ -163,7 +163,7 @@ class ExtractProcessor:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
@ -174,7 +174,7 @@ class ExtractProcessor:
credential_id=extract_setting.notion_info.credential_id,
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(


@ -8,9 +8,9 @@ class RerankRunnerFactory:
@staticmethod
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
match runner_type:
case RerankMode.RERANKING_MODEL.value:
case RerankMode.RERANKING_MODEL:
return RerankModelRunner(*args, **kwargs)
case RerankMode.WEIGHTED_SCORE.value:
case RerankMode.WEIGHTED_SCORE:
return WeightRerankRunner(*args, **kwargs)
case _:
raise ValueError(f"Unknown runner type: {runner_type}")


@ -61,7 +61,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@ -692,7 +692,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,


@ -9,8 +9,8 @@ class RetrievalMethod(Enum):
@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}
@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}
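
The membership checks keep accepting a plain retrieval_method string provided RetrievalMethod mixes in str (for example via StrEnum, as with the other enums touched in this change), since members then inherit str hashing and equality. A short sketch with assumed values:

from enum import StrEnum

class RetrievalMethod(StrEnum):  # assumed values for illustration
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"
    HYBRID_SEARCH = "hybrid_search"

def is_support_semantic_search(retrieval_method: str) -> bool:
    return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}

assert is_support_semantic_search("hybrid_search")
assert not is_support_semantic_search("full_text_search")
assert is_support_semantic_search(RetrievalMethod.SEMANTIC_SEARCH)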


@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
@ -122,7 +122,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
@ -134,15 +134,15 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
def get_supported_credential_types(self) -> list[CredentialType]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
types.append(CredentialType.API_KEY)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
types.append(CredentialType.OAUTH2)
return types
def get_tools(self) -> list[BuiltinTool]:


@ -61,7 +61,7 @@ class ToolProviderApiEntity(BaseModel):
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
parameter["type"] = "files"
if parameter.get("input_schema") is None:
parameter.pop("input_schema", None)
@ -110,7 +110,9 @@ class ToolProviderCredentialApiEntity(BaseModel):
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
supported_credential_types: list[CredentialType] = Field(
description="The supported credential types of the provider"
)
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)


@ -113,7 +113,7 @@ class ApiProviderAuthType(StrEnum):
# normalize & tiny alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
v = cls.API_KEY_HEADER.value
v = cls.API_KEY_HEADER
for mode in cls:
if mode.value == v:


@ -18,7 +18,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,


@ -17,7 +17,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",


@ -393,7 +393,7 @@ class ApiBasedToolSchemaParser:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
schema_type = ApiProviderSchemaType.OPENAPI
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
@ -403,7 +403,7 @@ class ApiBasedToolSchemaParser:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
schema_type = ApiProviderSchemaType.SWAGGER
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
@ -415,7 +415,7 @@ class ApiBasedToolSchemaParser:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e


@ -252,8 +252,8 @@ class AgentNode(Node):
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN.value) in (
ParamsAutoGenerated.CLOSE.value,
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
@ -269,7 +269,7 @@ class AgentNode(Node):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@ -420,7 +420,7 @@ class AgentNode(Node):
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
@ -479,7 +479,7 @@ class AgentNode(Node):
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _transform_message(
self,


@ -75,11 +75,11 @@ class DatasourceNode(Node):
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segement:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segement.value
@ -267,7 +267,7 @@ class DatasourceNode(Node):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []


@ -234,7 +234,7 @@ class HttpRequestNode(Node):
mapping = {
"tool_file_id": tool_file.id,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
file = file_factory.build_from_mapping(
mapping=mapping,


@ -95,7 +95,7 @@ class IterationNode(Node):
"config": {
"is_parallel": False,
"parallel_nums": 10,
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
"error_handle_mode": ErrorHandleMode.TERMINATED,
},
}


@ -27,7 +27,7 @@ from .exc import (
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
@ -77,7 +77,7 @@ class KnowledgeIndexNode(Node):
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
else:
is_preview = False
chunks = variable.value


@ -72,7 +72,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,


@ -92,7 +92,7 @@ def fetch_memory(
return None
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
@ -143,7 +143,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)


@ -945,7 +945,7 @@ class LLMNode(Node):
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
enable_jinja = False


@ -224,7 +224,7 @@ class ToolNode(Node):
return result
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []


@ -227,7 +227,7 @@ class WorkflowEntry:
"height": node_height,
"type": "custom",
"data": {
"type": NodeType.START.value,
"type": NodeType.START,
"title": "Start",
"desc": "Start",
},
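
Literal node dicts like the one above can carry the member directly: because a StrEnum member is a str, json.dumps and comparisons against stored strings only ever see the underlying value. A sketch with an assumed NodeType value:

import json
from enum import StrEnum

class NodeType(StrEnum):  # assumed value for illustration
    START = "start"

node_data = {"type": NodeType.START, "title": "Start", "desc": "Start"}

# Serializes exactly as if the plain string had been used.
assert json.dumps(node_data, sort_keys=True) == json.dumps(
    {"type": "start", "title": "Start", "desc": "Start"}, sort_keys=True
)
assert node_data["type"] == "start"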


@ -12,7 +12,7 @@ def handle(sender, **kwargs):
if synced_draft_workflow is None:
return
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
if node_data.get("data", {}).get("type") == NodeType.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(


@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL
]
if not knowledge_retrieval_nodes:


@ -139,7 +139,7 @@ def handle(sender: Message, **kwargs):
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
provider_type=ProviderType.SYSTEM,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),


@ -264,7 +264,7 @@ class FileLifecycleManager:
logger.warning("File %s not found in metadata", filename)
return False
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]["status"] = FileStatus.ARCHIVED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
@ -309,7 +309,7 @@ class FileLifecycleManager:
# Update metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]["status"] = FileStatus.DELETED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)


@ -45,7 +45,7 @@ def build_from_message_file(
}
# Set the correct ID field based on transfer method
if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
if message_file.transfer_method == FileTransferMethod.TOOL_FILE:
mapping["tool_file_id"] = message_file.upload_file_id
else:
mapping["upload_file_id"] = message_file.upload_file_id
@ -368,9 +368,7 @@ def _build_from_datasource_file(
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("datasource_file_id"),


@ -9,7 +9,7 @@ from .base import Base
from .types import StringUUID
class APIBasedExtensionPoint(enum.Enum):
class APIBasedExtensionPoint(enum.StrEnum):
APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query"
PING = "ping"
APP_MODERATION_INPUT = "app.moderation.input"


@ -184,7 +184,7 @@ class Dataset(Base):
@property
def retrieval_model_dict(self):
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,


@ -186,13 +186,13 @@ class App(Base):
if len(keys) >= 4:
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
if provider_type == ToolProviderType.API.value:
if provider_type == ToolProviderType.API:
try:
uuid.UUID(provider_id)
except Exception:
continue
api_provider_ids.append(provider_id)
if provider_type == ToolProviderType.BUILT_IN.value:
if provider_type == ToolProviderType.BUILT_IN:
try:
# check if it's hardcoded
try:
@ -251,23 +251,23 @@ class App(Base):
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
if provider_type == ToolProviderType.API.value:
if provider_type == ToolProviderType.API:
if uuid.UUID(provider_id) not in existing_api_providers:
deleted_tools.append(
{
"type": ToolProviderType.API.value,
"type": ToolProviderType.API,
"tool_name": tool["tool_name"],
"provider_id": provider_id,
}
)
if provider_type == ToolProviderType.BUILT_IN.value:
if provider_type == ToolProviderType.BUILT_IN:
generic_provider_id = GenericProviderID(provider_id)
if not existing_builtin_providers[generic_provider_id.provider_name]:
deleted_tools.append(
{
"type": ToolProviderType.BUILT_IN.value,
"type": ToolProviderType.BUILT_IN,
"tool_name": tool["tool_name"],
"provider_id": provider_id, # use the original one
}
@ -1154,7 +1154,7 @@ class Message(Base):
files: list[File] = []
for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if message_file.upload_file_id is None:
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
file = file_factory.build_from_mapping(
@ -1166,7 +1166,7 @@ class Message(Base):
},
tenant_id=current_app.tenant_id,
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL.value:
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
file = file_factory.build_from_mapping(
@ -1179,7 +1179,7 @@ class Message(Base):
},
tenant_id=current_app.tenant_id,
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
assert message_file.url is not None
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]


@ -107,7 +107,7 @@ class Provider(Base):
"""
Returns True if the provider is enabled.
"""
if self.provider_type == ProviderType.SYSTEM.value:
if self.provider_type == ProviderType.SYSTEM:
return self.is_valid
else:
return self.is_valid and self.token_is_set


@ -829,14 +829,14 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
if self.execution_metadata_dict:
from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
if self.node_type == NodeType.TOOL and "tool_info" in self.execution_metadata_dict:
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"],
provider_id=tool_info["provider_id"],
)
elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict:
elif self.node_type == NodeType.DATASOURCE and "datasource_info" in self.execution_metadata_dict:
datasource_info = self.execution_metadata_dict["datasource_info"]
extras["icon"] = datasource_info.get("icon")
return extras


@ -127,7 +127,7 @@ class AccountService:
if not account:
return None
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
@ -178,7 +178,7 @@ class AccountService:
if not account:
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
raise AccountLoginError("Account is banned.")
if password and invite_token and account.password is None:
@ -193,8 +193,8 @@ class AccountService:
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
@ -357,7 +357,7 @@ class AccountService:
@staticmethod
def close_account(account: Account):
"""Close account"""
account.status = AccountStatus.CLOSED.value
account.status = AccountStatus.CLOSED
db.session.commit()
@staticmethod
@ -397,8 +397,8 @@ class AccountService:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
db.session.commit()
access_token = AccountService.get_account_jwt_token(account=account)
@ -766,7 +766,7 @@ class AccountService:
if not account:
return None
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
return account
@ -1030,7 +1030,7 @@ class TenantService:
@staticmethod
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountRole.OWNER.value:
if role == TenantAccountRole.OWNER:
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
@ -1315,7 +1315,7 @@ class RegisterService:
password=password,
is_setup=is_setup,
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
@ -1376,7 +1376,7 @@ class RegisterService:
TenantService.create_tenant_member(tenant, account, role)
# Support resend invitation email when the account is pending status
if account.status != AccountStatus.PENDING.value:
if account.status != AccountStatus.PENDING:
raise AccountAlreadyInTenantError("Account already in tenant.")
token = cls.generate_invite_token(tenant, account)


@ -494,7 +494,7 @@ class AppDslService:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
@ -584,17 +584,17 @@ class AppDslService:
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
if not include_secret and data_type == NodeType.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
if not include_secret and data_type == NodeType.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
@ -658,31 +658,31 @@ class AppDslService:
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
case NodeType.TOOL:
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
case NodeType.LLM:
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
case NodeType.QUESTION_CLASSIFIER:
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
case NodeType.PARAMETER_EXTRACTOR:
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
case NodeType.KNOWLEDGE_RETRIEVAL:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:


@ -646,7 +646,7 @@ class DatasourceProviderService:
name=db_provider_name,
provider=provider_name,
plugin_id=plugin_id,
auth_type=CredentialType.API_KEY.value,
auth_type=CredentialType.API_KEY,
encrypted_credentials=credentials,
)
session.add(datasource_provider)
@ -674,7 +674,7 @@ class DatasourceProviderService:
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.name)
return secret_input_form_variables


@ -15,7 +15,7 @@ from models.dataset import Dataset, DatasetQuery
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,


@ -242,7 +242,7 @@ class PluginMigration:
if data.get("type") == "tool":
provider_name = data.get("provider_name")
provider_type = data.get("provider_type")
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN.value:
if provider_name not in excluded_providers and provider_type == ToolProviderType.BUILT_IN:
result.append(ToolProviderID(provider_name).plugin_id)
return result
@ -271,7 +271,7 @@ class PluginMigration:
try:
tool_entity = AgentToolEntity.model_validate(tool)
if (
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
tool_entity.provider_type == ToolProviderType.BUILT_IN
and tool_entity.provider_id not in excluded_providers
):
result.append(ToolProviderID(tool_entity.provider_id).plugin_id)

View File

@ -873,7 +873,7 @@ class RagPipelineService:
variable_pool = node_instance.graph_runtime_state.variable_pool
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
if invoke_from.value == InvokeFrom.PUBLISHED.value:
if invoke_from.value == InvokeFrom.PUBLISHED:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()

View File

@ -556,7 +556,7 @@ class RagPipelineDslService:
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
@ -613,7 +613,7 @@ class RagPipelineDslService:
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
type=WorkflowType.RAG_PIPELINE,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
@ -689,17 +689,17 @@ class RagPipelineDslService:
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
if data_type == NodeType.KNOWLEDGE_RETRIEVAL:
dataset_ids = node_data.get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
if not include_secret and data_type == NodeType.TOOL:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
if not include_secret and data_type == NodeType.AGENT:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
@ -733,35 +733,35 @@ class RagPipelineDslService:
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
case NodeType.TOOL:
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE.value:
case NodeType.DATASOURCE:
datasource_entity = DatasourceNodeData.model_validate(node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM.value:
case NodeType.LLM:
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
case NodeType.QUESTION_CLASSIFIER:
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
case NodeType.PARAMETER_EXTRACTOR:
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX.value:
case NodeType.KNOWLEDGE_INDEX:
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
@ -782,7 +782,7 @@ class RagPipelineDslService:
knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
case NodeType.KNOWLEDGE_RETRIEVAL:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
@ -927,7 +927,7 @@ class RagPipelineDslService:
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,

View File

@ -214,7 +214,7 @@ class RagPipelineTransformService:
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
type=WorkflowType.RAG_PIPELINE,
version="draft",
graph=json.dumps(graph),
created_by=current_user.id,
@ -226,7 +226,7 @@ class RagPipelineTransformService:
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
type=WorkflowType.RAG_PIPELINE,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=json.dumps(graph),
created_by=current_user.id,

View File

@ -277,7 +277,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@ -393,7 +393,7 @@ class ApiToolManageService:
icon="",
schema=schema,
description="",
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
schema_type_str=ApiProviderSchemaType.OPENAPI,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str=json.dumps(credentials),
)

View File

@ -50,16 +50,16 @@ class ToolTransformService:
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider"
)
if provider_type == ToolProviderType.BUILT_IN.value:
if provider_type == ToolProviderType.BUILT_IN:
return str(url_prefix / "builtin" / provider_name / "icon")
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
return json.loads(icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP.value:
elif provider_type == ToolProviderType.MCP:
return icon
return ""

View File

@ -134,7 +134,7 @@ class VectorService:
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,

View File

@ -36,7 +36,7 @@ class WebAppAuthService:
if not account:
raise AccountNotFoundError()
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
raise AccountLoginError("Account is banned.")
if account.password is None or not compare_password(password, account.password, account.password_salt):
@ -56,7 +56,7 @@ class WebAppAuthService:
if not account:
return None
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
return account

View File

@ -228,7 +228,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": "START",
"type": NodeType.START.value,
"type": NodeType.START,
"variables": [jsonable_encoder(v) for v in variables],
},
}
@ -273,7 +273,7 @@ class WorkflowConverter:
inputs[v.variable] = "{{#start." + v.variable + "#}}"
request_body = {
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value,
"point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
"params": {
"app_id": app_model.id,
"tool_variable": tool_variable,
@ -290,7 +290,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": f"HTTP REQUEST {api_based_extension.name}",
"type": NodeType.HTTP_REQUEST.value,
"type": NodeType.HTTP_REQUEST,
"method": "post",
"url": api_based_extension.api_endpoint,
"authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}},
@ -308,7 +308,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": f"Parse {api_based_extension.name} Response",
"type": NodeType.CODE.value,
"type": NodeType.CODE,
"variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}],
"code_language": "python3",
"code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads("
@ -348,7 +348,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": "KNOWLEDGE RETRIEVAL",
"type": NodeType.KNOWLEDGE_RETRIEVAL.value,
"type": NodeType.KNOWLEDGE_RETRIEVAL,
"query_variable_selector": query_variable_selector,
"dataset_ids": dataset_config.dataset_ids,
"retrieval_mode": retrieve_config.retrieve_strategy.value,
@ -396,16 +396,16 @@ class WorkflowConverter:
:param external_data_variable_node_mapping: external data variable node mapping
"""
# fetch start and knowledge retrieval node
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"]))
start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"]))
knowledge_retrieval_node = next(
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None
filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None
)
role_prefix = None
prompts: Any | None = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
if model_config.mode == LLMMode.CHAT:
if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
if not prompt_template.simple_prompt_template:
raise ValueError("Simple prompt template is required")
@ -517,7 +517,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": "LLM",
"type": NodeType.LLM.value,
"type": NodeType.LLM,
"model": {
"provider": model_config.provider,
"name": model_config.model,
@ -572,7 +572,7 @@ class WorkflowConverter:
"position": None,
"data": {
"title": "END",
"type": NodeType.END.value,
"type": NodeType.END,
"outputs": [{"variable": "result", "value_selector": ["llm", "text"]}],
},
}
@ -586,7 +586,7 @@ class WorkflowConverter:
return {
"id": "answer",
"position": None,
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
"data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str):

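These node dictionaries are eventually serialized (the converter stores the workflow graph as JSON), and a StrEnum member survives that step as its plain value because it is a str instance. A minimal sketch, assuming NodeType.START carries the value "start":

import json
from enum import StrEnum

class NodeType(StrEnum):    # illustrative subset
    START = "start"

node = {"data": {"title": "START", "type": NodeType.START}}
# json.dumps sees a str instance and emits its value, so the serialized graph
# is the same as it was with NodeType.START.value:
assert json.dumps(node) == '{"data": {"title": "START", "type": "start"}}'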
View File

@ -569,7 +569,7 @@ class WorkflowDraftVariableService:
system_instruction="",
system_instruction_tokens=0,
status="normal",
invoke_from=InvokeFrom.DEBUGGER.value,
invoke_from=InvokeFrom.DEBUGGER,
from_source="console",
from_end_user_id=None,
from_account_id=account_id,

View File

@ -74,7 +74,7 @@ class WorkflowRunService:
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
limit=limit,
last_id=last_id,
)

View File

@ -1006,7 +1006,7 @@ def _setup_variable_pool(
)
# Only add chatflow-specific variables for non-workflow types
if workflow.type != WorkflowType.WORKFLOW.value:
if workflow.type != WorkflowType.WORKFLOW:
system_variable.query = query
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 1

View File

@ -25,7 +25,7 @@ class TestChatMessageApiPermissions:
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
return app

View File

@ -23,7 +23,7 @@ class TestModelConfigResourcePermissions:
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.mode = AppMode.CHAT
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.app_model_config_id = str(uuid.uuid4())

View File

@ -542,7 +542,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
index=1,
node_execution_id=str(uuid.uuid4()),
node_id=self._node_id,
node_type=NodeType.LLM.value,
node_type=NodeType.LLM,
title="Test Node",
inputs='{"input": "test input"}',
process_data='{"test_var": "process_value", "other_var": "other_process"}',

View File

@ -44,25 +44,25 @@ class MockClient:
"hits": [
{
"_source": {
Field.CONTENT_KEY.value: "abcdef",
Field.VECTOR.value: [1, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "abcdef",
Field.VECTOR: [1, 2],
Field.METADATA_KEY: {},
},
"_score": 1.0,
},
{
"_source": {
Field.CONTENT_KEY.value: "123456",
Field.VECTOR.value: [2, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "123456",
Field.VECTOR: [2, 2],
Field.METADATA_KEY: {},
},
"_score": 0.9,
},
{
"_source": {
Field.CONTENT_KEY.value: "a1b2c3",
Field.VECTOR.value: [3, 2],
Field.METADATA_KEY.value: {},
Field.CONTENT_KEY: "a1b2c3",
Field.VECTOR: [3, 2],
Field.METADATA_KEY: {},
},
"_score": 0.8,
},

Some files were not shown because too many files have changed in this diff.