Merge branch 'feat/collaboration' into deploy/dev

hjlarry 2025-10-13 10:16:39 +08:00
commit 9fc2a0a3a1
239 changed files with 8164 additions and 4242 deletions

View File

@ -30,6 +30,8 @@ jobs:
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union

View File

@ -4,8 +4,7 @@ on:
push:
branches:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "deploy/**"
- "build/**"
- "release/e-*"
- "hotfix/**"

View File

@ -1,4 +1,4 @@
name: Deploy RAG Dev
name: Deploy Trigger Dev
permissions:
contents: read
@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/trigger-dev"
types:
- completed
@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600

View File

@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
@ -330,6 +331,7 @@ class MiddlewareConfig(
ClickzettaConfig,
HuaweiCloudConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,

View File

@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AlibabaCloudMySQLConfig(BaseSettings):
"""
Configuration settings for AlibabaCloud MySQL vector database
"""
ALIBABACLOUD_MYSQL_HOST: str = Field(
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
default="localhost",
)
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
default=3306,
)
ALIBABACLOUD_MYSQL_USER: str = Field(
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
default="root",
)
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
default="",
)
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
default="dify",
)
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the connection pool",
default=5,
)
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
default="utf8mb4",
)
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
"(e.g., 'cosine', 'euclidean')",
default="cosine",
)
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
default=6,
)
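How these fields get populated (a minimal sketch, assuming the class is used as a plain pydantic-settings BaseSettings on its own rather than through the MiddlewareConfig mixin shown earlier): pydantic-settings reads each field from the environment variable of the same name and falls back to the Field default, which is how the ALIBABACLOUD_MYSQL_* entries added to the .env example above reach the application. Names and values below are placeholders, not from this commit.

import os

# Hypothetical standalone usage; within Dify these fields are mixed into MiddlewareConfig.
os.environ["ALIBABACLOUD_MYSQL_HOST"] = "127.0.0.1"
os.environ["ALIBABACLOUD_MYSQL_PASSWORD"] = "root"

config = AlibabaCloudMySQLConfig()  # class defined above
assert config.ALIBABACLOUD_MYSQL_HOST == "127.0.0.1"
assert config.ALIBABACLOUD_MYSQL_PORT == 3306              # no env var set, Field default applies
assert config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION == "cosine"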

View File

@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AuthMethod(StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
"""
class AuthMethod(Enum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: str | None = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,

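A note on why the repo-wide removal of .value in this commit is safe: AuthMethod is moved to StrEnum here, and the other enums touched below (InvokeFrom, ImportStatus, DatasourceType, and so on) appear to follow the same str-backed pattern, so each member is itself a str and compares equal to, and serializes as, its string value. A minimal sketch of that behavior, assuming Python 3.11+ StrEnum (illustrative only, not code from this commit):

from enum import StrEnum

class AuthMethod(StrEnum):
    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"

# Members are str subclasses, so comparisons and formatting no longer need .value
assert AuthMethod.BASIC == "basic"
assert isinstance(AuthMethod.BASIC, str)
assert f"{AuthMethod.BASIC}" == "basic"          # str(member) is the plain value
assert AuthMethod("basic") is AuthMethod.BASIC   # lookup by value still works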
View File

@ -304,7 +304,7 @@ class AppCopyApi(Resource):
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),

View File

@ -70,9 +70,9 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -309,7 +309,7 @@ class ChatConversationApi(Resource):
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]:
case "created_at":

View File

@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models.account import Account
from models.model import AppMode, AppModelConfig
@ -172,6 +173,8 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_at = naive_utc_now()
db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)

View File

@ -52,7 +52,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if args["start"]:
@ -190,7 +190,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -263,7 +263,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -345,7 +345,7 @@ FROM
WHERE
c.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -432,7 +432,7 @@ LEFT JOIN
WHERE
m.app_id = :app_id
AND m.invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -509,7 +509,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
@ -584,7 +584,7 @@ FROM
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc

View File

@ -27,6 +27,7 @@ from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
@ -679,8 +680,12 @@ class PublishedWorkflowApi(Resource):
marked_comment=args.marked_comment or "",
)
app_model.workflow_id = workflow.id
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
# Update app_model within the same session to ensure atomicity
app_model_in_session = session.get(App, app_model.id)
if app_model_in_session:
app_model_in_session.workflow_id = workflow.id
app_model_in_session.updated_by = current_user.id
app_model_in_session.updated_at = naive_utc_now()
workflow_created_at = TimestampField().format(workflow.created_at)

View File

@ -47,7 +47,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -115,7 +115,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -183,7 +183,7 @@ WHERE
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)
@ -269,7 +269,7 @@ GROUP BY
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
}
timezone = pytz.timezone(account.timezone)

View File

@ -103,7 +103,7 @@ class ActivateApi(Resource):
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -130,11 +130,11 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
# Check account status
if account.status == AccountStatus.BANNED.value:
if account.status == AccountStatus.BANNED:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@ -256,7 +256,7 @@ class DataSourceNotionApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,

View File

@ -500,7 +500,7 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
@ -512,7 +512,7 @@ class DatasetIndexingEstimateApi(Resource):
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
@ -529,7 +529,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
@ -786,7 +786,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
@ -810,12 +810,13 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
@ -842,7 +843,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
@ -864,12 +865,13 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:

View File

@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
@ -546,7 +546,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
@ -563,7 +563,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],

View File

@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource):
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

View File

@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (

View File

@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
with Session(db.engine, expire_on_commit=False) as session:
end_user = (
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
session_id=user_id,
)
session.add(end_user)

View File

@ -197,12 +197,12 @@ class DatasetConfigManager:
# strategy
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
has_datasets = False
if config.get("agent_mode", {}).get("strategy") in {
PlanningStrategy.ROUTER.value,
PlanningStrategy.REACT_ROUTER.value,
PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER,
}:
for tool in config.get("agent_mode", {}).get("tools", []):
key = list(tool.keys())[0]

View File

@ -68,9 +68,13 @@ class ModelConfigConverter:
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
try:
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
except ValueError:
# Fall back to CHAT mode if the stored value is invalid
model_mode = LLMMode.CHAT
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config["model"]["mode"] == ModelMode.CHAT.value:
if config["model"]["mode"] == ModelMode.CHAT:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:

View File

@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode must be of boolean type")
if not agent_mode.get("strategy"):
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
agent_mode["strategy"] = PlanningStrategy.ROUTER
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

View File

@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")

View File

@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -100,8 +100,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
@ -244,8 +244,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow.id,
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT.value,
invoke_from=InvokeFrom.SERVICE_API.value,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)

View File

@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
parameter["type"] = "files"
# -------------

View File

@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
def as_normal_type(self):
return as_normal_type(self)

View File

@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_type == ProviderType.CUSTOM,
Provider.provider_name.in_(self._get_provider_names()),
)
@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
provider_type=ProviderType.CUSTOM.value,
provider_type=ProviderType.CUSTOM,
is_valid=True,
credential_id=new_record.id,
)
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -1,13 +1,13 @@
from typing import cast
import requests
import httpx
from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str):
@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
url = self.api_endpoint
try:
# proxy support for security
proxies = None
mounts: dict[str, httpx.BaseTransport] | None = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
"http": dify_config.SSRF_PROXY_HTTP_URL,
"https": dify_config.SSRF_PROXY_HTTPS_URL,
mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
response = requests.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
timeout=self.timeout,
proxies=proxies,
)
except requests.Timeout:
with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
response = client.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
)
except httpx.TimeoutException:
raise ValueError("request timeout")
except requests.ConnectionError:
except httpx.RequestError:
raise ValueError("request connection error")
if response.status_code != 200:

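The hunk above swaps requests' per-request proxies dict for httpx's client-level mounts. A minimal sketch of the equivalent pattern, assuming a recent httpx (0.26+) where HTTPTransport accepts proxy=; the proxy URL and endpoint below are placeholders, not values from this commit:

import httpx

proxy_url = "http://proxy.internal:3128"  # stand-in for dify_config.SSRF_PROXY_*_URL
mounts = {
    "http://": httpx.HTTPTransport(proxy=proxy_url),
    "https://": httpx.HTTPTransport(proxy=proxy_url),
}
timeout = httpx.Timeout(60.0, connect=5.0)  # 5s connect, 60s read/write/pool

with httpx.Client(mounts=mounts, timeout=timeout) as client:
    response = client.post(
        "https://extension.example.com/hook",  # placeholder endpoint
        json={"point": "ping", "params": {}},
    )
    response.raise_for_status()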
View File

@ -343,7 +343,7 @@ class IndexingRunner:
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
@ -356,7 +356,7 @@ class IndexingRunner:
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
@ -379,7 +379,7 @@ class IndexingRunner:
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],

View File

@ -224,8 +224,8 @@ def _handle_native_json_schema(
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
return model_parameters
@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
if ResponseFormat.JSON in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON
elif ResponseFormat.JSON_OBJECT in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
def _handle_prompt_based_schema(

View File

@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
span_kind = OpenInferenceSpanKindValues.CHAIN
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},

View File

@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=trace_id,
@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
metadata=metadata,
@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_data = LangfuseTrace(
id=trace_id,
user_id=user_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
input={
"message": trace_info.inputs,
"files": file_list,
@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
input=trace_info.inputs,
output={
"action": trace_info.action,
@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
generation_data = LangfuseGeneration(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
input=trace_info.inputs,
output=str(trace_info.suggested_question),
trace_id=trace_info.trace_id or trace_info.message_id,
@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
input=trace_info.inputs,
output={"documents": trace_info.documents},
trace_id=trace_info.trace_id or trace_info.message_id,
@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_generation_trace_data = LangfuseTrace(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
user_id=trace_info.tenant_id,
@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_trace(langfuse_trace_data=name_generation_trace_data)
name_generation_span_data = LangfuseSpan(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
input=trace_info.inputs,
output=trace_info.outputs,
trace_id=trace_info.conversation_id,

View File

@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_id:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
run_type=LangSmithRunType.chain,
@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
name=TraceTaskName.WORKFLOW_TRACE,
inputs=dict(trace_info.workflow_run_inputs),
run_type=LangSmithRunType.tool,
start_time=trace_info.workflow_data.created_at,
@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance):
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
id=message_id,
name=TraceTaskName.MESSAGE_TRACE.value,
name=TraceTaskName.MESSAGE_TRACE,
inputs=trace_info.inputs,
run_type=LangSmithRunType.chain,
start_time=trace_info.start_time,
@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
langsmith_run = LangSmithRunModel(
name=TraceTaskName.MODERATION_TRACE.value,
name=TraceTaskName.MODERATION_TRACE,
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data is None:
return
suggested_question_run = LangSmithRunModel(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
run_type=LangSmithRunType.tool,
@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if trace_info.message_data is None:
return
dataset_retrieval_run = LangSmithRunModel(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
run_type=LangSmithRunType.retriever,
@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
name_run = LangSmithRunModel(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
name=TraceTaskName.GENERATE_NAME_TRACE,
inputs=trace_info.inputs,
outputs=trace_info.outputs,
run_type=LangSmithRunType.tool,

View File

@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance):
"id": root_span_id,
"parent_span_id": None,
"trace_id": opik_trace_id,
"name": TraceTaskName.WORKFLOW_TRACE.value,
"name": TraceTaskName.WORKFLOW_TRACE,
"input": wrap_dict("input", trace_info.workflow_run_inputs),
"output": wrap_dict("output", trace_info.workflow_run_outputs),
"start_time": trace_info.start_time,
@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance):
else:
trace_data = {
"id": opik_trace_id,
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": workflow_metadata,
@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
"name": TraceTaskName.MESSAGE_TRACE.value,
"name": TraceTaskName.MESSAGE_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(metadata),
@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.MODERATION_TRACE.value,
"name": TraceTaskName.MODERATION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or message_data.updated_at,
@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
"type": "tool",
"start_time": start_time,
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
trace_data = {
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),
@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance):
span_data = {
"trace_id": trace.id,
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
"name": TraceTaskName.GENERATE_NAME_TRACE,
"start_time": trace_info.start_time,
"end_time": trace_info.end_time,
"metadata": wrap_metadata(trace_info.metadata),

View File

@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_info.message_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
total_tokens=trace_info.total_tokens,
@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance):
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_run_id,
op=str(TraceTaskName.WORKFLOW_TRACE.value),
op=str(TraceTaskName.WORKFLOW_TRACE),
inputs=dict(trace_info.workflow_run_inputs),
outputs=dict(trace_info.workflow_run_outputs),
attributes=workflow_attributes,
@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance):
message_run = WeaveTraceModel(
id=trace_id,
op=str(TraceTaskName.MESSAGE_TRACE.value),
op=str(TraceTaskName.MESSAGE_TRACE),
input_tokens=trace_info.message_tokens,
output_tokens=trace_info.answer_tokens,
total_tokens=trace_info.total_tokens,
@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance):
moderation_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.MODERATION_TRACE.value),
op=str(TraceTaskName.MODERATION_TRACE),
inputs=trace_info.inputs,
outputs={
"action": trace_info.action,
@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance):
suggested_question_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.suggested_question,
attributes=attributes,
@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance):
dataset_retrieval_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
inputs=trace_info.inputs,
outputs={"documents": trace_info.documents},
attributes=attributes,
@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance):
name_run = WeaveTraceModel(
id=str(uuid.uuid4()),
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
op=str(TraceTaskName.GENERATE_NAME_TRACE),
inputs=trace_info.inputs,
outputs=trace_info.outputs,
attributes=attributes,

View File

@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@ -83,13 +83,13 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
raise ValueError("prompt_messages must be a list")
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
if v[i]["role"] == PromptMessageRole.USER:
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
elif v[i]["role"] == PromptMessageRole.SYSTEM:
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
elif v[i]["role"] == PromptMessageRole.TOOL:
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage.model_validate(v[i])

View File

@ -2,11 +2,10 @@ import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import TypeVar
from typing import Any, TypeVar
import requests
import httpx
from pydantic import BaseModel
from requests.exceptions import HTTPError
from yarl import URL
from configs import dify_config
@ -47,29 +46,56 @@ class BasePluginClient:
data: bytes | dict | str | None = None,
params: dict | None = None,
files: dict | None = None,
stream: bool = False,
) -> requests.Response:
) -> httpx.Response:
"""
Make a request to the plugin daemon inner API.
"""
url = plugin_daemon_inner_api_baseurl / path
headers = headers or {}
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
headers["Accept-Encoding"] = "gzip, deflate, br"
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
request_kwargs: dict[str, Any] = {
"method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
request_kwargs["data"] = prepared_data
elif prepared_data is not None:
request_kwargs["content"] = prepared_data
try:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.ConnectionError:
response = httpx.request(**request_kwargs)
except httpx.RequestError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
return response
def _prepare_request(
self,
path: str,
headers: dict | None,
data: bytes | dict | str | None,
params: dict | None,
files: dict | None,
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
prepared_data: bytes | dict | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
if isinstance(data, dict):
if prepared_headers.get("Content-Type") == "application/json":
prepared_data = json.dumps(data)
else:
prepared_data = data
return str(url), prepared_headers, prepared_data, params, files
def _stream_request(
self,
method: str,
@ -78,17 +104,38 @@ class BasePluginClient:
headers: dict | None = None,
data: bytes | dict | None = None,
files: dict | None = None,
) -> Generator[bytes, None, None]:
) -> Generator[str, None, None]:
"""
Make a stream request to the plugin daemon inner API
"""
response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
stream_kwargs: dict[str, Any] = {
"method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data
elif prepared_data is not None:
stream_kwargs["content"] = prepared_data
try:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if raw_line is None:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
except httpx.RequestError:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
def _stream_request_with_model(
self,
@ -139,7 +186,7 @@ class BasePluginClient:
try:
response = self._request(method, path, headers, data, params, files)
response.raise_for_status()
except HTTPError as e:
except httpx.HTTPStatusError as e:
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
except Exception as e:

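On the _stream_request return type changing from Generator[bytes, None, None] to Generator[str, None, None]: requests' iter_lines yields bytes, while httpx's Response.iter_lines yields decoded str, so the extra decode in the new code is only a defensive fallback. A minimal sketch of the streaming pattern, with a placeholder URL rather than the plugin daemon API:

import httpx

def iter_data_lines(url: str):
    """Yield the payload of 'data:' lines from a streaming endpoint as str."""
    with httpx.stream("GET", url, timeout=httpx.Timeout(30.0)) as response:
        for line in response.iter_lines():  # httpx yields already-decoded str
            line = line.strip()
            if line.startswith("data:"):
                payload = line[5:].strip()
                if payload:
                    yield payload

# Usage (hypothetical endpoint):
# for chunk in iter_data_lines("http://localhost:5002/stream"):
#     print(chunk)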
View File

@ -610,7 +610,7 @@ class ProviderManager:
provider_quota_to_provider_record_dict = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -702,7 +702,7 @@ class ProviderManager:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None
)
if not custom_provider_record:
@ -905,7 +905,7 @@ class ProviderManager:
# Convert provider_records to dict
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
@ -1082,7 +1082,7 @@ class ProviderManager:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -46,7 +46,7 @@ class DataPostProcessor:
reranking_model: dict | None = None,
weights: dict | None = None,
) -> BaseRerankRunner | None:
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
runner = RerankRunnerFactory.create_rerank_runner(
runner_type=reranking_mode,
tenant_id=tenant_id,
@ -62,7 +62,7 @@ class DataPostProcessor:
),
)
return runner
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
elif reranking_mode == RerankMode.RERANKING_MODEL:
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
if rerank_model_instance is None:
return None

View File

@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@ -107,7 +107,7 @@ class RetrievalService:
raise ValueError(";\n".join(exceptions))
# Deduplicate documents for hybrid search to avoid duplicate chunks
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
all_documents = cls._deduplicate_documents(all_documents)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
@ -245,10 +245,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
@ -293,10 +293,10 @@ class RetrievalService:
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(

View File

@ -0,0 +1,388 @@
import hashlib
import json
import logging
import uuid
from contextlib import contextmanager
from typing import Any, Literal, cast
import mysql.connector
from mysql.connector import Error as MySQLError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class AlibabaCloudMySQLVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str
max_connection: int
charset: str = "utf8mb4"
distance_function: Literal["cosine", "euclidean"] = "cosine"
hnsw_m: int = 6
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
if not values.get("port"):
raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
if not values.get("user"):
raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
if values.get("password") is None:
raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
if not values.get("database"):
raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
if not values.get("max_connection"):
raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
return values
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id VARCHAR(36) PRIMARY KEY,
text LONGTEXT NOT NULL,
meta JSON NOT NULL,
embedding VECTOR({dimension}) NOT NULL,
VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
"""
SQL_CREATE_META_INDEX = """
CREATE INDEX idx_{index_hash}_meta ON {table_name}
((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
"""
SQL_CREATE_FULLTEXT_INDEX = """
CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
"""
class AlibabaCloudMySQLVector(BaseVector):
def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = collection_name.lower()
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
self.distance_function = config.distance_function.lower()
self.hnsw_m = config.hnsw_m
self._check_vector_support()
def get_type(self) -> str:
return VectorType.ALIBABACLOUD_MYSQL
def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
# Create connection pool using mysql-connector-python pooling
pool_config: dict[str, Any] = {
"host": config.host,
"port": config.port,
"user": config.user,
"password": config.password,
"database": config.database,
"charset": config.charset,
"autocommit": True,
"pool_name": f"pool_{self.collection_name}",
"pool_size": config.max_connection,
"pool_reset_session": True,
}
return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
def _check_vector_support(self):
"""Check if the MySQL server supports vector operations."""
try:
with self._get_cursor() as cur:
# Check MySQL version and vector support
cur.execute("SELECT VERSION()")
version = cur.fetchone()["VERSION()"]
logger.debug("Connected to MySQL version: %s", version)
# Try to execute a simple vector function to verify support
cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
result = cur.fetchone()
if not result or not result.get("vector_support"):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
)
except MySQLError as e:
if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
raise ValueError(
"RDS MySQL Vector functions are not available."
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
) from e
raise e
@contextmanager
def _get_cursor(self):
conn = self.pool.get_connection()
cur = conn.cursor(dictionary=True)
try:
yield cur
finally:
cur.close()
conn.close()
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
# Convert embedding list to Aliyun MySQL vector format
vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
vector_str,
)
)
with self._get_cursor() as cur:
insert_sql = (
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
)
cur.executemany(insert_sql, values)
return pks
def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
return cur.fetchone() is not None
def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_cursor() as cur:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete_by_ids(self, ids: list[str]):
# Avoiding crashes caused by performing delete operations on empty lists
if not ids:
return
with self._get_cursor() as cur:
try:
placeholders = ",".join(["%s"] * len(ids))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
except MySQLError as e:
if e.errno == 1146: # Table doesn't exist
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
return
else:
raise e
def delete_by_metadata_field(self, key: str, value: str):
with self._get_cursor() as cur:
cur.execute(
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
:param query_vector: The input vector to search for similar items.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
# Convert query vector to RDS MySQL vector format
query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
# Use RDS MySQL's native vector distance functions
with self._get_cursor() as cur:
# Choose distance function based on configuration
distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
# Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
# Use column alias in ORDER BY to avoid calculating distance twice
sql = f"""
SELECT meta, text,
{distance_func}(embedding, VEC_FromText(%s)) AS distance
FROM {self.table_name}
{where_clause}
ORDER BY distance
LIMIT %s
"""
query_params = [query_vector_str] + params + [top_k]
cur.execute(sql, query_params)
docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0)
for record in cur:
try:
distance = float(record["distance"])
# Convert distance to similarity score
if self.distance_function == "cosine":
# For cosine distance: similarity = 1 - distance
similarity = 1.0 - distance
else:
# For euclidean distance: use inverse relationship
# similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + distance)
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = similarity
metadata["distance"] = distance
if similarity >= score_threshold:
docs.append(Document(page_content=record["text"], metadata=metadata))
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Error processing search result: %s", e)
continue
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
params = []
if document_ids_filter:
placeholders = ",".join(["%s"] * len(document_ids_filter))
where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
params.extend(document_ids_filter)
with self._get_cursor() as cur:
# Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
query_params = [query, query] + params + [top_k]
cur.execute(
f"""SELECT meta, text,
MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
FROM {self.table_name}
WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
{where_clause}
ORDER BY score DESC
LIMIT %s""",
query_params,
)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = float(record["score"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
def _create_collection(self, dimension: int):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{collection_exist_cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
if redis_client.get(collection_exist_cache_key):
return
with self._get_cursor() as cur:
# Create table with vector column and vector index
cur.execute(
SQL_CREATE_TABLE.format(
table_name=self.table_name,
dimension=dimension,
distance_function=self.distance_function,
hnsw_m=self.hnsw_m,
)
)
# Create metadata index (check if exists first)
try:
cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create meta index: %s", e)
# Create full-text index for text search
try:
cur.execute(
SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
)
except MySQLError as e:
if e.errno != 1061: # Duplicate key name
logger.warning("Could not create fulltext index: %s", e)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
"""Validate and return the distance function as a proper Literal type."""
if distance_function not in ["cosine", "euclidean"]:
raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
return cast(Literal["cosine", "euclidean"], distance_function)
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
)
return AlibabaCloudMySQLVector(
collection_name=collection_name,
config=AlibabaCloudMySQLVectorConfig(
host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
port=dify_config.ALIBABACLOUD_MYSQL_PORT,
user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
distance_function=self._validate_distance_function(
dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
),
hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
),
)
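
As a quick reference, here is a minimal sketch of the KNN query shape this class issues, run directly through mysql-connector-python. It is not part of the change: the host, credentials, and table name are placeholders (the real table name is derived from the collection name), and it assumes an RDS MySQL 8.0.36+ instance with the Vector functions enabled.

import mysql.connector

conn = mysql.connector.connect(
    host="127.0.0.1", port=3306, user="root", password="root", database="dify"
)
cur = conn.cursor(dictionary=True)

# Serialize the query vector the same way add_texts/search_by_vector do.
query_vector = [0.1, 0.2, 0.3]
vector_str = "[" + ",".join(map(str, query_vector)) + "]"

# Same shape as search_by_vector: cosine distance, ORDER BY + LIMIT so the
# optimizer can use the HNSW vector index. "embedding_placeholder" stands in
# for the per-collection table name.
cur.execute(
    """
    SELECT meta, text,
           VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance
    FROM embedding_placeholder
    ORDER BY distance
    LIMIT %s
    """,
    [vector_str, 4],
)
for row in cur.fetchall():
    print(1.0 - float(row["distance"]), row["text"][:60])

cur.close()
conn.close()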

View File

@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector):
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
{Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector):
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
# Check if vector index already exists on the embedding column
if Field.VECTOR.value in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
if Field.VECTOR in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR})
PROPERTIES (
"distance.function" = "{self._config.vector_distance_function}",
"scalar.type" = "f32",
@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector):
# More precise check: look for inverted index specifically on the content column
if (
"inverted" in idx_str
and Field.CONTENT_KEY.value.lower() in idx_str
and Field.CONTENT_KEY.lower() in idx_str
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY})
PROPERTIES (
"analyzer" = "{self._config.analyzer_type}",
"mode" = "{self._config.analyzer_mode}"
@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector):
or "with the same type" in error_msg
or "cannot create inverted index" in error_msg
) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY)
# Try to get the existing index name for logging
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower():
logger.info("Found existing inverted index: %s", idx)
break
except (RuntimeError, ValueError):
@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector):
# Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector):
# Use json_extract_string function for ClickZetta compatibility
sql = (
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?"
)
cursor.execute(sql, binding_params=[value])
@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector):
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
)
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}")
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}")
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
# Execute vector search query
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY},
{distance_func}({Field.VECTOR}, {query_vector_str}) AS distance
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
ORDER BY distance
@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector):
# match_all requires all terms to be present
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')")
where_clause = " AND ".join(filter_clauses)
# Execute full-text search query
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector):
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
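
The .value removals in this hunk, and in the vector-store hunks that follow, read as a pure refactor: they assume Field is a string-backed enum (a StrEnum, like the VectorType shown later in this commit), so its members already behave as plain strings in f-strings, comparisons, and str method calls. A minimal sketch of that behavior, with illustrative member values:

from enum import StrEnum  # Python 3.11+

class Field(StrEnum):
    # Values here are illustrative, not copied from the real Field enum.
    CONTENT_KEY = "page_content"
    METADATA_KEY = "metadata"

assert Field.CONTENT_KEY == "page_content"
assert f"SELECT {Field.CONTENT_KEY} FROM t" == "SELECT page_content FROM t"
assert Field.CONTENT_KEY.lower() == "page_content"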

View File

@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector):
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
Field.CONTENT_KEY: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -4,7 +4,7 @@ import math
from typing import Any, cast
from urllib.parse import urlparse
import requests
from elasticsearch import ConnectionError as ElasticsearchConnectionError
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except requests.ConnectionError as e:
except ElasticsearchConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query_str = {
"bool": {
"must": {"match": {Field.CONTENT_KEY.value: query}},
"must": {"match": {Field.CONTENT_KEY: query}},
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
}
}
@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector):
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type

View File

@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector):
index=self._collection_name,
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i] or None,
Field.METADATA_KEY: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector):
"size": top_k,
"query": {
"vector": {
Field.VECTOR.value: {
Field.VECTOR: {
"vector": query_vector,
"topk": top_k,
}
@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {"match": {Field.CONTENT_KEY.value: query}}
query_str = {"match": {Field.CONTENT_KEY: query}}
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
)
)
@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: { # Make sure the dimension is correct here
"type": "vector",
"dimension": dim,
"indexing": True,
@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector):
"neighbors": 32,
"efc": 128,
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type

View File

@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector):
}
}
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector):
search_query: dict[str, Any] = {
"size": top_k,
"_source": True,
"query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
"query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
}
final_ext: dict[str, Any] = {"lvector": {}}
if filters is not None and len(filters) > 0:
# when using filter, transform filter from List[Dict] to Dict as valid format
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict
search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict
final_ext["lvector"]["filter_type"] = "pre_filter"
if final_ext != {"lvector": {}}:
@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector):
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"][Field.VECTOR],
metadata=hit["_source"][Field.METADATA_KEY],
),
hit["_score"],
)
@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector):
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {

View File

@ -85,7 +85,7 @@ class MilvusVector(BaseVector):
collection_info = self._client.describe_collection(self._collection_name)
fields = [field["name"] for field in collection_info["fields"]]
# Since primary field is auto-id, no need to track it
self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value]
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
def _check_hybrid_search_support(self) -> bool:
"""
@ -130,9 +130,9 @@ class MilvusVector(BaseVector):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
@ -243,15 +243,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
anns_field=Field.VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -264,7 +264,7 @@ class MilvusVector(BaseVector):
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
)
return []
if not self.field_exists(Field.SPARSE_VECTOR.value):
if not self.field_exists(Field.SPARSE_VECTOR):
logger.warning(
"Full-text search unavailable: collection missing 'sparse_vector' field; "
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
@ -279,15 +279,15 @@ class MilvusVector(BaseVector):
results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
anns_field=Field.SPARSE_VECTOR,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
filter=filter,
)
return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
@ -311,7 +311,7 @@ class MilvusVector(BaseVector):
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
@ -326,15 +326,15 @@ class MilvusVector(BaseVector):
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))
schema = CollectionSchema(fields)
@ -342,8 +342,8 @@ class MilvusVector(BaseVector):
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
input_field_names=[Field.CONTENT_KEY],
output_field_names=[Field.SPARSE_VECTOR],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)
@ -352,12 +352,12 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
)
# Create the collection

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Literal
from typing import Any
from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel):
port: int
secure: bool = False # use_ssl
verify_certs: bool = True
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
auth_method: AuthMethod = AuthMethod.BASIC
user: str | None = None
password: str | None = None
aws_region: str | None = None
@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector):
"_op_type": "index",
"_index": self._collection_name.lower(),
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector):
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector):
query = {
"size": kwargs.get("top_k", 4),
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
"query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {
"script_score": {
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
"script": {
"source": "knn_score",
"lang": "knn",
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
"params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
},
}
}
@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
metadata = hit["_source"].get(Field.METADATA_KEY, {})
# Make sure metadata is a dictionary
if metadata is None:
@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
return docs
@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector):
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
metadata = hit["_source"].get(Field.METADATA_KEY)
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector):
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
Field.CONTENT_KEY: {"type": "text"},
Field.VECTOR: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector):
"parameters": {"ef_construction": 64, "m": 8},
},
},
Field.METADATA_KEY.value: {
Field.METADATA_KEY: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
aws_region=dify_config.OPENSEARCH_AWS_REGION,

View File

@ -289,7 +289,8 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名nt: 机构名
# nr: person name, ns: place name, nt: organization name
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
current_entity += word
else:
if current_entity:
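
A tiny illustration of the pseg.cut call this loop consumes (hypothetical input; requires the jieba package): it yields (word, part-of-speech) pairs, and the loop above keeps tokens whose tag is in the allowed set, joining adjacent ones into entities.

import jieba.posseg as pseg

for word, pos in pseg.cut("阿里云数据库支持向量检索"):
    print(word, pos)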

View File

@ -213,7 +213,7 @@ class VastbaseVector(BaseVector):
with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
# Vastbase 支持的向量维度取值范围为 [1,16000]
# Vastbase supports vector dimensions in range [1, 16000]
if dimension <= 16000:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)

View File

@ -147,15 +147,13 @@ class QdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -165,9 +163,7 @@ class QdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -220,10 +216,10 @@ class QdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -381,12 +377,12 @@ class QdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -433,7 +429,7 @@ class QdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents

View File

@ -55,7 +55,7 @@ class TableStoreVector(BaseVector):
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
self._tags_field = f"{Field.METADATA_KEY}_tags"
def create_collection(self, embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
@ -64,7 +64,7 @@ class TableStoreVector(BaseVector):
def get_by_ids(self, ids: list[str]) -> list[Document]:
docs = []
request = BatchGetRowRequest()
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
rows_to_get = [[("id", _id)] for _id in ids]
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
@ -73,11 +73,7 @@ class TableStoreVector(BaseVector):
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(
Document(
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
)
)
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
return docs
def get_type(self) -> str:
@ -95,9 +91,9 @@ class TableStoreVector(BaseVector):
self._write_row(
primary_key=uuids[i],
attributes={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
Field.CONTENT_KEY: documents[i].page_content,
Field.VECTOR: embeddings[i],
Field.METADATA_KEY: documents[i].metadata,
},
)
return uuids
@ -180,7 +176,7 @@ class TableStoreVector(BaseVector):
field_schemas = [
tablestore.FieldSchema(
Field.CONTENT_KEY.value,
Field.CONTENT_KEY,
tablestore.FieldType.TEXT,
analyzer=tablestore.AnalyzerType.MAXWORD,
index=True,
@ -188,7 +184,7 @@ class TableStoreVector(BaseVector):
store=False,
),
tablestore.FieldSchema(
Field.VECTOR.value,
Field.VECTOR,
tablestore.FieldType.VECTOR,
vector_options=tablestore.VectorOptions(
data_type=tablestore.VectorDataType.VD_FLOAT_32,
@ -197,7 +193,7 @@ class TableStoreVector(BaseVector):
),
),
tablestore.FieldSchema(
Field.METADATA_KEY.value,
Field.METADATA_KEY,
tablestore.FieldType.KEYWORD,
index=True,
store=False,
@ -233,15 +229,15 @@ class TableStoreVector(BaseVector):
pk = [("id", primary_key)]
tags = []
for key, value in attributes[Field.METADATA_KEY.value].items():
for key, value in attributes[Field.METADATA_KEY].items():
tags.append(str(key) + "=" + str(value))
attribute_columns = [
(Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]),
(Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])),
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
(
Field.METADATA_KEY.value,
json.dumps(attributes[Field.METADATA_KEY.value]),
Field.METADATA_KEY,
json.dumps(attributes[Field.METADATA_KEY]),
),
(self._tags_field, json.dumps(tags)),
]
@ -270,7 +266,7 @@ class TableStoreVector(BaseVector):
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)
@ -288,7 +284,7 @@ class TableStoreVector(BaseVector):
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
field_name=Field.VECTOR,
top_k=top_k,
float32_query_vector=query_vector,
)
@ -311,8 +307,8 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector_str = ots_column_map.get(Field.VECTOR)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
@ -321,7 +317,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
@ -343,7 +339,7 @@ class TableStoreVector(BaseVector):
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
@ -374,10 +370,10 @@ class TableStoreVector(BaseVector):
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR.value)
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None
if score:
@ -385,7 +381,7 @@ class TableStoreVector(BaseVector):
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)

View File

@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Union
import httpx
import qdrant_client
import requests
from flask import current_app
from httpx import DigestAuth
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
@ -19,7 +20,6 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector):
# create group_id payload index
self._client.create_payload_index(
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
)
# create doc_id payload index
self._client.create_payload_index(
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
)
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
# create document_id payload index
self._client.create_payload_index(
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
)
# create full text index
text_index_params = TextIndexParams(
@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector):
max_token_len=20,
lowercase=True,
)
self._client.create_payload_index(
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector):
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
Field.CONTENT_KEY,
Field.METADATA_KEY,
group_id or "",
Field.GROUP_KEY.value,
Field.GROUP_KEY,
),
)
]
@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector):
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
metadata = result.payload.get(Field.METADATA_KEY) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
page_content=result.payload.get(Field.CONTENT_KEY, ""),
metadata=metadata,
)
docs.append(doc)
@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector):
documents = []
for result in results:
if result:
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(document)
return documents
@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
}
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
response = requests.post(
response = httpx.post(
f"{tidb_config.api_url}/clusters",
json=cluster_data,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:
@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
body = {"password": new_password}
response = requests.put(
response = httpx.put(
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
json=body,
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
)
if response.status_code == 200:

View File

@ -2,8 +2,8 @@ import time
import uuid
from collections.abc import Sequence
import requests
from requests.auth import HTTPDigestAuth
import httpx
from httpx import DigestAuth
from configs import dify_config
from extensions.ext_database import db
@ -49,7 +49,7 @@ class TidbService:
"rootPassword": password,
}
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@ -83,7 +83,7 @@ class TidbService:
:return: The response from the API.
"""
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@ -102,7 +102,7 @@ class TidbService:
:return: The response from the API.
"""
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@ -127,10 +127,10 @@ class TidbService:
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = requests.patch(
response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=HTTPDigestAuth(public_key, private_key),
auth=DigestAuth(public_key, private_key),
)
if response.status_code == 200:
@ -161,9 +161,7 @@ class TidbService:
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@ -224,8 +222,8 @@ class TidbService:
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = requests.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)
if response.status_code == 200:
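
A minimal sketch of the requests -> httpx swap applied throughout this file: httpx.DigestAuth(username, password) mirrors requests.auth.HTTPDigestAuth and is passed through the same auth= keyword (the endpoint and keys below are placeholders). One difference worth keeping in mind is that httpx applies a default timeout of five seconds, whereas requests waits indefinitely unless a timeout is given.

import httpx

response = httpx.get(
    "https://api.example.com/api/v1beta/clusters/123",  # placeholder endpoint
    params={"view": "BASIC"},
    auth=httpx.DigestAuth("public_key", "private_key"),  # placeholder keys
    timeout=30.0,
)
if response.status_code == 200:
    print(response.json())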

View File

@ -55,13 +55,13 @@ class TiDBVector(BaseVector):
return Table(
self._collection_name,
self._orm_base.metadata,
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
Column(
Field.VECTOR.value,
Field.VECTOR,
VectorType(dim),
nullable=False,
),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column(Field.TEXT_KEY, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(

View File

@ -71,6 +71,12 @@ class Vector:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

View File

@ -2,6 +2,7 @@ from enum import StrEnum
class VectorType(StrEnum):
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
ANALYTICDB = "analyticdb"
CHROMA = "chroma"
MILVUS = "milvus"

View File

@ -76,11 +76,11 @@ class VikingDBVector(BaseVector):
if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
]
self._client.create_collection(
@ -100,7 +100,7 @@ class VikingDBVector(BaseVector):
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY.value,
partition_by=vdb_Field.GROUP_KEY,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -126,11 +126,11 @@ class VikingDBVector(BaseVector):
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
vdb_Field.GROUP_KEY.value: self._group_id,
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY: page_content,
vdb_Field.METADATA_KEY: json.dumps(metadata),
vdb_Field.GROUP_KEY: self._group_id,
}
)
docs.append(doc)
@ -151,7 +151,7 @@ class VikingDBVector(BaseVector):
# Note: Metadata field value is an dict, but vikingdb field
# not support json type
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)
@ -161,7 +161,7 @@ class VikingDBVector(BaseVector):
ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
@ -189,12 +189,12 @@ class VikingDBVector(BaseVector):
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs

View File

@ -2,7 +2,6 @@ import datetime
import json
from typing import Any
import requests
import weaviate # type: ignore
from pydantic import BaseModel, model_validator
@ -45,8 +44,8 @@ class WeaviateVector(BaseVector):
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.ConnectionError:
raise ConnectionError("Vector database connection error")
except Exception as exc:
raise ConnectionError("Vector database connection error") from exc
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
@ -105,7 +104,7 @@ class WeaviateVector(BaseVector):
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
data_properties = {Field.TEXT_KEY: text}
if metadatas is not None:
# metadata may be None
for key, val in (metadatas[i] or {}).items():
@ -183,7 +182,7 @@ class WeaviateVector(BaseVector):
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
@ -205,7 +204,7 @@ class WeaviateVector(BaseVector):
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
@ -233,7 +232,7 @@ class WeaviateVector(BaseVector):
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
properties.append(Field.TEXT_KEY)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
@ -251,7 +250,7 @@ class WeaviateVector(BaseVector):
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
text = res.pop(Field.TEXT_KEY)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs

View File

@ -20,12 +20,12 @@ class BaseDatasourceEvent(BaseModel):
class DatasourceErrorEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.ERROR.value
event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR
error: str = Field(..., description="error message")
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
event: DatasourceStreamEvent = DatasourceStreamEvent.COMPLETED
data: Mapping[str, Any] | list = Field(..., description="result")
total: int | None = Field(default=0, description="total")
completed: int | None = Field(default=0, description="completed")
@ -33,6 +33,6 @@ class DatasourceCompletedEvent(BaseDatasourceEvent):
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
event: DatasourceStreamEvent = DatasourceStreamEvent.PROCESSING
total: int | None = Field(..., description="total")
completed: int | None = Field(..., description="completed")
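
The event classes above retype "event" from a plain string to the enum itself. A minimal sketch, assuming DatasourceStreamEvent is a StrEnum with these illustrative member values, showing that the field still serializes to the same string:

from enum import StrEnum

from pydantic import BaseModel, Field


class DatasourceStreamEvent(StrEnum):  # assumed member values, for illustration only
    PROCESSING = "processing"
    COMPLETED = "completed"
    ERROR = "error"


class DatasourceErrorEvent(BaseModel):
    event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR
    error: str = Field(..., description="error message")


print(DatasourceErrorEvent(error="boom").model_dump(mode="json"))
# {'event': 'error', 'error': 'boom'}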

View File

@ -45,7 +45,7 @@ class ExtractProcessor:
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = "\n"
@ -76,7 +76,7 @@ class ExtractProcessor:
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
delimiter = "\n"
return delimiter.join(
@ -92,7 +92,7 @@ class ExtractProcessor:
def extract(
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
@ -163,7 +163,7 @@ class ExtractProcessor:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
@ -174,7 +174,7 @@ class ExtractProcessor:
credential_id=extract_setting.notion_info.credential_id,
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(

View File

@ -2,7 +2,7 @@ import json
import time
from typing import Any, cast
import requests
import httpx
from extensions.ext_storage import storage
@ -104,18 +104,18 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
response = requests.post(url, headers=headers, json=data)
response = httpx.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response:
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
response = requests.get(url, headers=headers)
response = httpx.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
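
The FirecrawlApp hunk swaps requests for httpx while keeping the retry-on-502 behaviour. A standalone sketch of that pattern; the function name and defaults are illustrative, not part of the codebase:

import time

import httpx


def post_with_retry(url: str, payload: dict, headers: dict,
                    retries: int = 3, backoff_factor: float = 0.5) -> httpx.Response:
    # Retry only on 502 responses, backing off exponentially between attempts.
    response = httpx.post(url, headers=headers, json=payload)
    for attempt in range(1, retries):
        if response.status_code != 502:
            break
        time.sleep(backoff_factor * (2 ** attempt))
        response = httpx.post(url, headers=headers, json=payload)
    return response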

View File

@ -3,7 +3,7 @@ import logging
import operator
from typing import Any, cast
import requests
import httpx
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@ -92,7 +92,7 @@ class NotionExtractor(BaseExtractor):
if next_cursor:
current_query["start_cursor"] = next_cursor
res = requests.post(
res = httpx.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,
@ -160,7 +160,7 @@ class NotionExtractor(BaseExtractor):
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
try:
res = requests.request(
res = httpx.request(
"GET",
block_url,
headers={
@ -173,7 +173,7 @@ class NotionExtractor(BaseExtractor):
if res.status_code != 200:
raise ValueError(f"Error fetching Notion block data: {res.text}")
data = res.json()
except requests.RequestException as e:
except httpx.HTTPError as e:
raise ValueError("Error fetching Notion block data") from e
if "results" not in data or not isinstance(data["results"], list):
raise ValueError("Error fetching Notion block data")
@ -222,7 +222,7 @@ class NotionExtractor(BaseExtractor):
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
res = requests.request(
res = httpx.request(
"GET",
block_url,
headers={
@ -282,7 +282,7 @@ class NotionExtractor(BaseExtractor):
while not done:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
res = requests.request(
res = httpx.request(
"GET",
block_url,
headers={
@ -354,7 +354,7 @@ class NotionExtractor(BaseExtractor):
query_dict: dict[str, Any] = {}
res = requests.request(
res = httpx.request(
"GET",
retrieve_page_url,
headers={

View File

@ -3,8 +3,8 @@ from collections.abc import Generator
from typing import Union
from urllib.parse import urljoin
import requests
from requests import Response
import httpx
from httpx import Response
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
@ -20,28 +20,45 @@ class BaseAPIClient:
self.session = self.init_session()
def init_session(self):
session = requests.Session()
session.headers.update({"X-API-Key": self.api_key})
session.headers.update({"Content-Type": "application/json"})
session.headers.update({"Accept": "application/json"})
session.headers.update({"User-Agent": "WaterCrawl-Plugin"})
session.headers.update({"Accept-Language": "en-US"})
return session
headers = {
"X-API-Key": self.api_key,
"Content-Type": "application/json",
"Accept": "application/json",
"User-Agent": "WaterCrawl-Plugin",
"Accept-Language": "en-US",
}
return httpx.Client(headers=headers, timeout=None)
def _request(
self,
method: str,
endpoint: str,
query_params: dict | None = None,
data: dict | None = None,
**kwargs,
) -> Response:
stream = kwargs.pop("stream", False)
url = urljoin(self.base_url, endpoint)
if stream:
request = self.session.build_request(method, url, params=query_params, json=data)
return self.session.send(request, stream=True, **kwargs)
return self.session.request(method, url, params=query_params, json=data, **kwargs)
def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
return self._request("GET", endpoint, query_params=query_params, **kwargs)
def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)
def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)
def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
return self._request("DELETE", endpoint, query_params=query_params, **kwargs)
def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)
class WaterCrawlAPIClient(BaseAPIClient):
@ -49,14 +66,17 @@ class WaterCrawlAPIClient(BaseAPIClient):
super().__init__(api_key, base_url)
def process_eventstream(self, response: Response, download: bool = False) -> Generator:
for line in response.iter_lines():
line = line.decode("utf-8")
if line.startswith("data:"):
line = line[5:].strip()
data = json.loads(line)
if data["type"] == "result" and download:
data["data"] = self.download_result(data["data"])
yield data
try:
for raw_line in response.iter_lines():
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
if line.startswith("data:"):
line = line[5:].strip()
data = json.loads(line)
if data["type"] == "result" and download:
data["data"] = self.download_result(data["data"])
yield data
finally:
response.close()
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
if response.status_code == 401:
@ -170,7 +190,10 @@ class WaterCrawlAPIClient(BaseAPIClient):
return event_data["data"]
def download_result(self, result_object: dict):
response = requests.get(result_object["result"])
response.raise_for_status()
result_object["result"] = response.json()
response = httpx.get(result_object["result"], timeout=None)
try:
response.raise_for_status()
result_object["result"] = response.json()
finally:
response.close()
return result_object
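
The WaterCrawl client above replaces a requests.Session with a shared httpx.Client: default headers move onto the client, and streamed responses go through build_request() plus send(stream=True), which leaves closing to the caller. A minimal sketch under those assumptions (placeholder key, no real endpoints):

import httpx

client = httpx.Client(
    headers={"X-API-Key": "example-key", "Accept": "application/json"},  # placeholder key
    timeout=None,
)


def get_json(url: str) -> dict:
    # Non-streaming requests are read eagerly, much like with requests.Session.
    response = client.get(url)
    response.raise_for_status()
    return response.json()


def get_stream(url: str) -> httpx.Response:
    # Streamed responses are not pre-read; the caller iterates and then closes them.
    request = client.build_request("GET", url)
    return client.send(request, stream=True)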

View File

@ -9,7 +9,7 @@ import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree
import requests
import httpx
from docx import Document as DocxDocument
from configs import dify_config
@ -43,15 +43,19 @@ class WordExtractor(BaseExtractor):
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
r = requests.get(self.file_path)
response = httpx.get(self.file_path, timeout=None)
if r.status_code != 200:
raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
if response.status_code != 200:
response.close()
raise ValueError(f"Check the url of your file; returned status code {response.status_code}")
self.web_path = self.file_path
# TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
self.temp_file.write(r.content)
try:
self.temp_file.write(response.content)
finally:
response.close()
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")

View File

@ -8,9 +8,9 @@ class RerankRunnerFactory:
@staticmethod
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
match runner_type:
case RerankMode.RERANKING_MODEL.value:
case RerankMode.RERANKING_MODEL:
return RerankModelRunner(*args, **kwargs)
case RerankMode.WEIGHTED_SCORE.value:
case RerankMode.WEIGHTED_SCORE:
return WeightRerankRunner(*args, **kwargs)
case _:
raise ValueError(f"Unknown runner type: {runner_type}")

View File

@ -61,7 +61,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@ -692,7 +692,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,

View File

@ -9,8 +9,8 @@ class RetrievalMethod(Enum):
@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}
@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}
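
The .value removals here, and in many of the other hunks, rely on these enums being StrEnum subclasses (Python 3.11+), whose members are str instances and compare equal to their values. A minimal sketch with assumed member values:

from enum import StrEnum


class RetrievalMethod(StrEnum):  # assumed values, for illustration only
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"
    HYBRID_SEARCH = "hybrid_search"


# Plain strings compare equal to the members, so .value is redundant in these checks.
assert RetrievalMethod.SEMANTIC_SEARCH == "semantic_search"
assert "hybrid_search" in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}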

View File

@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
@ -122,7 +122,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
@ -134,15 +134,15 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
def get_supported_credential_types(self) -> list[CredentialType]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
types.append(CredentialType.API_KEY)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
types.append(CredentialType.OAUTH2)
return types
def get_tools(self) -> list[BuiltinTool]:

View File

@ -290,6 +290,7 @@ class ApiTool(Tool):
method_lc
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
url,
max_retries=0,
params=params,
headers=headers,
cookies=cookies,

View File

@ -61,7 +61,7 @@ class ToolProviderApiEntity(BaseModel):
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
parameter["type"] = "files"
if parameter.get("input_schema") is None:
parameter.pop("input_schema", None)
@ -110,7 +110,9 @@ class ToolProviderCredentialApiEntity(BaseModel):
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
supported_credential_types: list[CredentialType] = Field(
description="The supported credential types of the provider"
)
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)

View File

@ -113,7 +113,7 @@ class ApiProviderAuthType(StrEnum):
# normalize & tiny alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
v = cls.API_KEY_HEADER.value
v = cls.API_KEY_HEADER
for mode in cls:
if mode.value == v:

View File

@ -18,7 +18,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,

View File

@ -17,7 +17,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",

View File

@ -4,8 +4,8 @@ from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any
import httpx
from flask import request
from requests import get
from yaml import YAMLError, safe_load
from core.tools.entities.common_entities import I18nObject
@ -334,15 +334,20 @@ class ApiBasedToolSchemaParser:
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
response = httpx.get(
api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5
)
try:
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
finally:
response.close()
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
@ -388,7 +393,7 @@ class ApiBasedToolSchemaParser:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
schema_type = ApiProviderSchemaType.OPENAPI
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
@ -398,7 +403,7 @@ class ApiBasedToolSchemaParser:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
schema_type = ApiProviderSchemaType.SWAGGER
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
@ -410,7 +415,7 @@ class ApiBasedToolSchemaParser:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e

View File

@ -252,7 +252,10 @@ class AgentNode(Node):
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
params[key] = value_param.get("value", "") if value_param is not None else None
else:
@ -266,7 +269,7 @@ class AgentNode(Node):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@ -417,7 +420,7 @@ class AgentNode(Node):
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
@ -476,7 +479,7 @@ class AgentNode(Node):
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _transform_message(
self,

View File

@ -75,11 +75,11 @@ class DatasourceNode(Node):
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segement:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segement.value
@ -267,7 +267,7 @@ class DatasourceNode(Node):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []

View File

@ -234,7 +234,7 @@ class HttpRequestNode(Node):
mapping = {
"tool_file_id": tool_file.id,
"transfer_method": FileTransferMethod.TOOL_FILE.value,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
file = file_factory.build_from_mapping(
mapping=mapping,

View File

@ -95,7 +95,7 @@ class IterationNode(Node):
"config": {
"is_parallel": False,
"parallel_nums": 10,
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
"error_handle_mode": ErrorHandleMode.TERMINATED,
},
}

View File

@ -27,7 +27,7 @@ from .exc import (
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
@ -77,7 +77,7 @@ class KnowledgeIndexNode(Node):
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
if invoke_from:
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
is_preview = invoke_from.value == InvokeFrom.DEBUGGER
else:
is_preview = False
chunks = variable.value

View File

@ -72,7 +72,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,

View File

@ -92,7 +92,7 @@ def fetch_memory(
return None
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value

View File

@ -956,7 +956,7 @@ class LLMNode(Node):
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
enable_jinja = False

View File

@ -224,7 +224,7 @@ class ToolNode(Node):
return result
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []

View File

@ -227,7 +227,7 @@ class WorkflowEntry:
"height": node_height,
"type": "custom",
"data": {
"type": NodeType.START.value,
"type": NodeType.START,
"title": "Start",
"desc": "Start",
},

View File

@ -12,7 +12,7 @@ def handle(sender, **kwargs):
if synced_draft_workflow is None:
return
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
if node_data.get("data", {}).get("type") == NodeType.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(

View File

@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL.value
node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL
]
if not knowledge_retrieval_nodes:

View File

@ -138,7 +138,6 @@ def init_app(app: DifyApp):
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
from opentelemetry.propagate import set_global_textmap
@ -238,7 +237,6 @@ def init_app(app: DifyApp):
instrument_exception_logging()
init_sqlalchemy_instrumentor(app)
RedisInstrumentor().instrument()
RequestsInstrumentor().instrument()
HTTPXClientInstrumentor().instrument()
atexit.register(shutdown_tracer)

View File

@ -264,7 +264,7 @@ class FileLifecycleManager:
logger.warning("File %s not found in metadata", filename)
return False
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]["status"] = FileStatus.ARCHIVED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
@ -309,7 +309,7 @@ class FileLifecycleManager:
# Update metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]["status"] = FileStatus.DELETED
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)

View File

@ -45,7 +45,7 @@ def build_from_message_file(
}
# Set the correct ID field based on transfer method
if message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
if message_file.transfer_method == FileTransferMethod.TOOL_FILE:
mapping["tool_file_id"] = message_file.upload_file_id
else:
mapping["upload_file_id"] = message_file.upload_file_id
@ -368,9 +368,7 @@ def _build_from_datasource_file(
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("datasource_file_id"),

View File

@ -9,7 +9,7 @@ from .base import Base
from .types import StringUUID
class APIBasedExtensionPoint(enum.Enum):
class APIBasedExtensionPoint(enum.StrEnum):
APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query"
PING = "ping"
APP_MODERATION_INPUT = "app.moderation.input"

View File

@ -61,18 +61,18 @@ class Dataset(Base):
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = mapped_column(db.String(255), nullable=True)
embedding_model_provider = mapped_column(db.String(255), nullable=True)
keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=db.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
icon_info = db.Column(JSONB, nullable=True)
runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = db.Column(StringUUID, nullable=True)
chunk_structure = db.Column(db.String(255), nullable=True)
enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
icon_info = mapped_column(JSONB, nullable=True)
runtime_mode = mapped_column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = mapped_column(StringUUID, nullable=True)
chunk_structure = mapped_column(db.String(255), nullable=True)
enable_api = mapped_column(sa.Boolean, nullable=False, server_default=db.text("true"))
@property
def total_documents(self):
@ -184,7 +184,7 @@ class Dataset(Base):
@property
def retrieval_model_dict(self):
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
@ -1226,21 +1226,21 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
__tablename__ = "pipeline_built_in_templates"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
copyright = db.Column(db.String(255), nullable=False)
privacy_policy = db.Column(db.String(255), nullable=False)
position = db.Column(db.Integer, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = db.Column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True)
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
copyright = mapped_column(db.String(255), nullable=False)
privacy_policy = mapped_column(db.String(255), nullable=False)
position = mapped_column(sa.Integer, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@property
def created_user_name(self):
@ -1257,20 +1257,20 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False)
chunk_structure = db.Column(db.String(255), nullable=False)
icon = db.Column(db.JSON, nullable=False)
position = db.Column(db.Integer, nullable=False)
yaml_content = db.Column(db.Text, nullable=False)
install_count = db.Column(db.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False)
chunk_structure = mapped_column(db.String(255), nullable=False)
icon = mapped_column(sa.JSON, nullable=False)
position = mapped_column(sa.Integer, nullable=False)
yaml_content = mapped_column(sa.Text, nullable=False)
install_count = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(db.String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_user_name(self):
@ -1284,17 +1284,17 @@ class Pipeline(Base): # type: ignore[name-defined]
__tablename__ = "pipelines"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
workflow_id = db.Column(StringUUID, nullable=True)
is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(sa.Text, nullable=False, server_default=db.text("''::character varying"))
workflow_id = mapped_column(StringUUID, nullable=True)
is_public = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
is_published = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
def retrieve_dataset(self, session: Session):
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
@ -1307,25 +1307,25 @@ class DocumentPipelineExecutionLog(Base):
db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
pipeline_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
datasource_type = db.Column(db.String(255), nullable=False)
datasource_info = db.Column(db.Text, nullable=False)
datasource_node_id = db.Column(db.String(255), nullable=False)
input_data = db.Column(db.JSON, nullable=False)
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
pipeline_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
datasource_type = mapped_column(db.String(255), nullable=False)
datasource_info = mapped_column(sa.Text, nullable=False)
datasource_node_id = mapped_column(db.String(255), nullable=False)
input_data = mapped_column(sa.JSON, nullable=False)
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class PipelineRecommendedPlugin(Base):
__tablename__ = "pipeline_recommended_plugins"
__table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id = db.Column(db.Text, nullable=False)
provider_name = db.Column(db.Text, nullable=False)
position = db.Column(db.Integer, nullable=False, default=0)
active = db.Column(db.Boolean, nullable=False, default=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
plugin_id = mapped_column(sa.Text, nullable=False)
provider_name = mapped_column(sa.Text, nullable=False)
position = mapped_column(sa.Integer, nullable=False, default=0)
active = mapped_column(sa.Boolean, nullable=False, default=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
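
The model changes above move from Flask-SQLAlchemy's db.Column to SQLAlchemy 2.0-style mapped_column, pulling plain column types in through the sa alias. A minimal standalone sketch of that style on a hypothetical table (not one of the tables in this diff):

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExamplePipeline(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_pipelines"

    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()
    )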

View File

@ -186,13 +186,13 @@ class App(Base):
if len(keys) >= 4:
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
if provider_type == ToolProviderType.API.value:
if provider_type == ToolProviderType.API:
try:
uuid.UUID(provider_id)
except Exception:
continue
api_provider_ids.append(provider_id)
if provider_type == ToolProviderType.BUILT_IN.value:
if provider_type == ToolProviderType.BUILT_IN:
try:
# check if it's hardcoded
try:
@ -251,23 +251,23 @@ class App(Base):
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
if provider_type == ToolProviderType.API.value:
if provider_type == ToolProviderType.API:
if uuid.UUID(provider_id) not in existing_api_providers:
deleted_tools.append(
{
"type": ToolProviderType.API.value,
"type": ToolProviderType.API,
"tool_name": tool["tool_name"],
"provider_id": provider_id,
}
)
if provider_type == ToolProviderType.BUILT_IN.value:
if provider_type == ToolProviderType.BUILT_IN:
generic_provider_id = GenericProviderID(provider_id)
if not existing_builtin_providers[generic_provider_id.provider_name]:
deleted_tools.append(
{
"type": ToolProviderType.BUILT_IN.value,
"type": ToolProviderType.BUILT_IN,
"tool_name": tool["tool_name"],
"provider_id": provider_id, # use the original one
}
@ -1212,7 +1212,7 @@ class Message(Base):
files: list[File] = []
for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if message_file.upload_file_id is None:
raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id")
file = file_factory.build_from_mapping(
@ -1224,7 +1224,7 @@ class Message(Base):
},
tenant_id=current_app.tenant_id,
)
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL.value:
elif message_file.transfer_method == FileTransferMethod.REMOTE_URL:
if message_file.url is None:
raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url")
file = file_factory.build_from_mapping(
@ -1237,7 +1237,7 @@ class Message(Base):
},
tenant_id=current_app.tenant_id,
)
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE.value:
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE:
if message_file.upload_file_id is None:
assert message_file.url is not None
message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0]

Some files were not shown because too many files have changed in this diff.