mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/trigger
commit f3b415c095
@@ -32,6 +32,8 @@ jobs:
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
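These ast-grep rewrites mechanically migrate legacy SQLAlchemy idioms across the code base. As a rough before/after sketch of the first two patterns (a hypothetical model, not taken from the repository; Query.where() is the 2.0-style synonym for Query.filter()):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Session

class Base(DeclarativeBase):
    pass

class App(Base):  # hypothetical model used only for illustration
    __tablename__ = "apps"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Before the rewrite: session.query(App).filter(App.name == "demo").all()
    # After the rewrite, same result:
    apps = session.query(App).where(App.name == "demo").all()

The db.Column / mapped_column patterns move model declarations toward the same SQLAlchemy 2.0 style; the heredoc that follows writes an ast-grep YAML rule for the Optional[T] to T | None conversion (the rule body is truncated in this hunk).
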
@@ -343,6 +343,15 @@ OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false

# AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306
ALIBABACLOUD_MYSQL_USER=root
ALIBABACLOUD_MYSQL_PASSWORD=root
ALIBABACLOUD_MYSQL_DATABASE=dify
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6

# openGauss configuration
OPENGAUSS_HOST=127.0.0.1
OPENGAUSS_PORT=6600

@@ -1570,6 +1570,14 @@ def transform_datasource_credentials():
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
    auth_count += 1
    if not firecrawl_tenant_credential.credentials:
        click.echo(
            click.style(
                f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
                fg="yellow",
            )
        )
        continue
    # get credential api key
    credentials_json = json.loads(firecrawl_tenant_credential.credentials)
    api_key = credentials_json.get("config", {}).get("api_key")

@@ -1625,6 +1633,14 @@ def transform_datasource_credentials():
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
    auth_count += 1
    if not jina_tenant_credential.credentials:
        click.echo(
            click.style(
                f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
                fg="yellow",
            )
        )
        continue
    # get credential api key
    credentials_json = json.loads(jina_tenant_credential.credentials)
    api_key = credentials_json.get("config", {}).get("api_key")

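For context, the credential rows parsed above are JSON blobs; a minimal sketch of the shape the migration expects, with a placeholder key (real values are tenant-specific):

import json

stored = '{"config": {"api_key": "example-key"}}'  # hypothetical stored credential
credentials_json = json.loads(stored)
api_key = credentials_json.get("config", {}).get("api_key")
print(api_key)  # example-key
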
@@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig

@@ -330,6 +331,7 @@ class MiddlewareConfig(
    ClickzettaConfig,
    HuaweiCloudConfig,
    MilvusConfig,
    AlibabaCloudMySQLConfig,
    MyScaleConfig,
    OpenSearchConfig,
    OracleConfig,

@@ -0,0 +1,54 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class AlibabaCloudMySQLConfig(BaseSettings):
    """
    Configuration settings for AlibabaCloud MySQL vector database
    """

    ALIBABACLOUD_MYSQL_HOST: str = Field(
        description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
        default="localhost",
    )

    ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
        description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
        default=3306,
    )

    ALIBABACLOUD_MYSQL_USER: str = Field(
        description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
        default="root",
    )

    ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
        description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
        default="",
    )

    ALIBABACLOUD_MYSQL_DATABASE: str = Field(
        description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
        default="dify",
    )

    ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
        description="Maximum number of connections in the connection pool",
        default=5,
    )

    ALIBABACLOUD_MYSQL_CHARSET: str = Field(
        description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
        default="utf8mb4",
    )

    ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
        description="Distance function used for vector similarity search in AlibabaCloud MySQL "
        "(e.g., 'cosine', 'euclidean')",
        default="cosine",
    )

    ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
        description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
        default=6,
    )
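A minimal usage sketch, assuming standard pydantic-settings behavior: each field is populated from the environment variable of the same name, which is how the .env entries added above reach the configuration object.

import os

os.environ["ALIBABACLOUD_MYSQL_HOST"] = "127.0.0.1"
os.environ["ALIBABACLOUD_MYSQL_MAX_CONNECTION"] = "10"

config = AlibabaCloudMySQLConfig()
print(config.ALIBABACLOUD_MYSQL_HOST)            # 127.0.0.1
print(config.ALIBABACLOUD_MYSQL_MAX_CONNECTION)  # 10
print(config.ALIBABACLOUD_MYSQL_PORT)            # 3306 (default)
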
@@ -1,23 +1,24 @@
from enum import Enum
from enum import StrEnum
from typing import Literal

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class AuthMethod(StrEnum):
    """
    Authentication method for OpenSearch
    """

    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"


class OpenSearchConfig(BaseSettings):
    """
    Configuration settings for OpenSearch
    """

    class AuthMethod(Enum):
        """
        Authentication method for OpenSearch
        """

        BASIC = "basic"
        AWS_MANAGED_IAM = "aws_managed_iam"

    OPENSEARCH_HOST: str | None = Field(
        description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
        default=None,

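Promoting the nested Enum to a module-level StrEnum is the same pattern applied throughout this merge: when an enum is str-backed, its members compare equal to their string values, so the explicit .value can be dropped in comparisons and query parameters (InvokeFrom, ImportStatus, AccountStatus, DatasourceType and the other enums below are presumably str-backed in the same way). A small illustrative sketch, not taken from the repository:

from enum import Enum, StrEnum

class OldAuthMethod(Enum):
    BASIC = "basic"

class AuthMethod(StrEnum):
    BASIC = "basic"

print(OldAuthMethod.BASIC == "basic")  # False: plain Enum members are not their values
print(AuthMethod.BASIC == "basic")     # True: StrEnum members are str instances
print(f"auth={AuthMethod.BASIC}")      # auth=basic - no .value needed
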
@ -1,5 +1,4 @@
|
|||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
|
|
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
|
|
|
|||
|
|
@ -304,7 +304,7 @@ class AppCopyApi(Resource):
|
|||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
|
|
|
|||
|
|
@ -70,9 +70,9 @@ class AppImportApi(Resource):
|
|||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
|
|||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
|
|||
|
|
@ -309,7 +309,7 @@ class ChatConversationApi(Resource):
|
|||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.tools.tool_manager import ToolManager
|
|||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
|
@ -172,6 +173,8 @@ class ModelConfigResource(Resource):
|
|||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -127,7 +127,7 @@ class DailyConversationStatistic(Resource):
|
|||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
|
||||
if args["start"]:
|
||||
|
|
@ -190,7 +190,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -263,7 +263,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -345,7 +345,7 @@ FROM
|
|||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -432,7 +432,7 @@ LEFT JOIN
|
|||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -509,7 +509,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
@ -584,7 +584,7 @@ FROM
|
|||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from factories import file_factory, variable_factory
|
|||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
|
|
@ -682,8 +683,12 @@ class PublishedWorkflowApi(Resource):
|
|||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
# Update app_model within the same session to ensure atomicity
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -115,7 +115,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -183,7 +183,7 @@ WHERE
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
@ -269,7 +269,7 @@ GROUP BY
|
|||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ class ActivateApi(Resource):
|
|||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -130,11 +130,11 @@ class OAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ class DataSourceNotionApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,79 @@ def _validate_name(name: str) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get supported retrieval methods based on vector database type.
|
||||
|
||||
Args:
|
||||
vector_type: Vector database type, can be None
|
||||
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||
|
||||
Returns:
|
||||
Dictionary containing supported retrieval methods
|
||||
|
||||
Raises:
|
||||
ValueError: If vector_type is None or unsupported
|
||||
"""
|
||||
if vector_type is None:
|
||||
raise ValueError("Vector store type is not configured.")
|
||||
|
||||
# Define vector database types that only support semantic search
|
||||
semantic_only_types = {
|
||||
VectorType.RELYT,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.CHROMA,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
|
||||
# Define vector database types that support all retrieval methods
|
||||
full_search_types = {
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.ELASTICSEARCH_JA,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.TIDB_ON_QDRANT,
|
||||
VectorType.LINDORM,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
full_methods = {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
|
||||
if vector_type == VectorType.MILVUS:
|
||||
return semantic_methods if is_mock else full_methods
|
||||
|
||||
if vector_type in semantic_only_types:
|
||||
return semantic_methods
|
||||
elif vector_type in full_search_types:
|
||||
return full_methods
|
||||
else:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
|
|
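The _get_retrieval_methods_by_vector_type helper introduced above replaces the duplicated match statements removed further down in this file. A hedged usage sketch (the retrieval-method strings are assumed from RetrievalMethod's values):

methods = _get_retrieval_methods_by_vector_type(VectorType.MILVUS, is_mock=True)
print(methods)  # {"retrieval_method": ["semantic_search"]} - the mock API keeps Milvus semantic-only

methods = _get_retrieval_methods_by_vector_type(VectorType.QDRANT)
print(methods["retrieval_method"])  # semantic, full-text and hybrid search for full-featured stores

_get_retrieval_methods_by_vector_type(None)  # raises ValueError: Vector store type is not configured.
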
@ -500,7 +573,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
|
|
@ -512,7 +585,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
|
|
@ -529,7 +602,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": website_info_list["provider"],
|
||||
|
|
@ -777,49 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
|
|
@ -832,48 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
|
@ -538,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
|
|
@ -546,7 +546,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
|
|
@ -563,7 +563,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
|
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ class RagPipelineImportApi(Resource):
|
|||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ class RagPipelineImportConfirmApi(Resource):
|
|||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import InstalledApp
|
||||
from models.account import Account
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
|
|
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
assert isinstance(current_user, Account)
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
|
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
|
|
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
|
|
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,6 +23,8 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_type = request.args.get("type", type=str, default="")
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
|
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
|
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
|
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
return current_user, tenant_id
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
|
|
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
|
|||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
|
|
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
|
|
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
|
|||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
|
|
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -26,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
|
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import AccountStatus
|
||||
from libs.login import current_user
|
||||
from models.account import Account, AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
|
@ -25,11 +25,16 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
assert isinstance(current_user, Account)
|
||||
return current_user
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
account = _current_account()
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
raise AccountNotInitializedError()
|
||||
|
|
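In this code base flask_login's current_user may be an Account (console) or an EndUser (web app / service API), so the new _current_account() helper narrows it once and later attribute access such as current_tenant_id stays type-safe. A minimal sketch of the idea (simplified stand-in classes, not the real models):

class Account:
    current_tenant_id: str | None = "tenant-1"

class EndUser:
    pass

def _current_account(current_user: object) -> Account:
    assert isinstance(current_user, Account)  # rejects EndUser / anonymous users
    return current_user

account = _current_account(Account())
print(account.current_tenant_id)  # tenant-1
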
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
|||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
|
@ -85,7 +85,7 @@ def get_user_tenant(view: Callable[P, R] | None = None):
|
|||
raise ValueError("tenant_id is required")
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
|
|
|
|||
|
|
@ -313,7 +313,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
|
|||
Create or update session terminal based on user ID.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
end_user = (
|
||||
|
|
@ -332,7 +332,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None =
|
|||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="service_api",
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
|
||||
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(end_user)
|
||||
|
|
|
|||
|
|
@ -197,12 +197,12 @@ class DatasetConfigManager:
|
|||
|
||||
# strategy
|
||||
if "strategy" not in config["agent_mode"] or not config["agent_mode"].get("strategy"):
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER
|
||||
|
||||
has_datasets = False
|
||||
if config.get("agent_mode", {}).get("strategy") in {
|
||||
PlanningStrategy.ROUTER.value,
|
||||
PlanningStrategy.REACT_ROUTER.value,
|
||||
PlanningStrategy.ROUTER,
|
||||
PlanningStrategy.REACT_ROUTER,
|
||||
}:
|
||||
for tool in config.get("agent_mode", {}).get("tools", []):
|
||||
key = list(tool.keys())[0]
|
||||
|
|
|
|||
|
|
@ -68,9 +68,13 @@ class ModelConfigConverter:
# get model mode
model_mode = model_config.mode
if not model_mode:
model_mode = LLMMode.CHAT.value
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
try:
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
except ValueError:
# Fall back to CHAT mode if the stored value is invalid
model_mode = LLMMode.CHAT

if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
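The new try/except matters because constructing an enum member from an unknown value raises ValueError; a tiny illustration with a stand-in enum (LLMMode itself lives in Dify's model runtime and is not redefined here):

from enum import StrEnum

class Mode(StrEnum):  # stand-in for LLMMode, illustrative only
    CHAT = "chat"
    COMPLETION = "completion"

assert Mode("chat") is Mode.CHAT
try:
    Mode("banana")  # unknown stored value
except ValueError:
    mode = Mode.CHAT  # fall back to chat, as the hunk above now does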
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class PromptTemplateConfigManager:
|
|||
if config["model"]["mode"] not in model_mode_vals:
|
||||
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION:
|
||||
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
|
||||
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ class PromptTemplateConfigManager:
|
|||
if not assistant_prefix:
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
|
||||
|
||||
if config["model"]["mode"] == ModelMode.CHAT.value:
|
||||
if config["model"]["mode"] == ModelMode.CHAT:
|
||||
prompt_list = config["chat_prompt_config"]["prompt"]
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
|||
raise ValueError("enabled in agent_mode must be of boolean type")
|
||||
|
||||
if not agent_mode.get("strategy"):
|
||||
agent_mode["strategy"] = PlanningStrategy.ROUTER.value
|
||||
agent_mode["strategy"] = PlanningStrategy.ROUTER
|
||||
|
||||
if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]:
|
||||
raise ValueError("strategy in agent_mode must be in the specified strategy list")
|
||||
|
|
|
|||
|
|
@ -198,9 +198,9 @@ class AgentChatAppRunner(AppRunner):
|
|||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION.value:
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
|
|
|
|||
|
|
@ -229,8 +229,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -101,8 +101,8 @@ class WorkflowBasedAppRunner:
|
|||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
@ -245,8 +245,8 @@ class WorkflowBasedAppRunner:
|
|||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
user_id="",
|
||||
user_from=UserFrom.ACCOUNT.value,
|
||||
invoke_from=InvokeFrom.SERVICE_API.value,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
import enum
from enum import Enum
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field, ValidationInfo, field_validator

@ -54,16 +54,16 @@ class DatasourceParameter(PluginParameter):
removes TOOLS_SELECTOR from PluginParameterType
"""

STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
STRING = PluginParameterType.STRING
NUMBER = PluginParameterType.NUMBER
BOOLEAN = PluginParameterType.BOOLEAN
SELECT = PluginParameterType.SELECT
SECRET_INPUT = PluginParameterType.SECRET_INPUT
FILE = PluginParameterType.FILE
FILES = PluginParameterType.FILES

# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES

def as_normal_type(self):
return as_normal_type(self)

@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")

class DatasourceInvokeFrom(Enum):
class DatasourceInvokeFrom(StrEnum):
"""
Enum class for datasource invoke
"""
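The switch from Enum to StrEnum is what allows the .value suffixes throughout this change to be dropped; a small standard-library illustration (Python 3.11+, class names are illustrative):

from enum import Enum, StrEnum

class OldStyle(Enum):
    ROUTER = "router"

class NewStyle(StrEnum):
    ROUTER = "router"

assert OldStyle.ROUTER != "router"        # a plain Enum member never equals the raw string
assert OldStyle.ROUTER.value == "router"  # so comparisons needed .value
assert NewStyle.ROUTER == "router"        # a StrEnum member is a str with that value
assert f"{NewStyle.ROUTER}" == "router"   # and it formats as its value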
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_type == ProviderType.CUSTOM,
|
||||
Provider.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
|
|
@ -458,7 +458,7 @@ class ProviderConfiguration(BaseModel):
|
|||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
provider_type=ProviderType.CUSTOM,
|
||||
is_valid=True,
|
||||
credential_id=new_record.id,
|
||||
)
|
||||
|
|
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
from typing import cast

import requests
import httpx

from configs import dify_config
from models.api_based_extension import APIBasedExtensionPoint


class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60)
timeout: httpx.Timeout = httpx.Timeout(60.0, connect=5.0)
"""timeout for request connect and read"""

def __init__(self, api_endpoint: str, api_key: str):

@ -27,25 +27,23 @@ class APIBasedExtensionRequestor:
url = self.api_endpoint

try:
# proxy support for security
proxies = None
mounts: dict[str, httpx.BaseTransport] | None = None
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxies = {
"http": dify_config.SSRF_PROXY_HTTP_URL,
"https": dify_config.SSRF_PROXY_HTTPS_URL,
mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}

response = requests.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
timeout=self.timeout,
proxies=proxies,
)
except requests.Timeout:
with httpx.Client(mounts=mounts, timeout=self.timeout) as client:
response = client.request(
method="POST",
url=url,
json={"point": point.value, "params": params},
headers=headers,
)
except httpx.TimeoutException:
raise ValueError("request timeout")
except requests.ConnectionError:
except httpx.RequestError:
raise ValueError("request connection error")

if response.status_code != 200:
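For reference, a requests-style proxies dict keys on scheme names, while httpx mounts key on URL prefixes and take transports; a minimal sketch of the equivalent wiring, assuming a single SSRF proxy URL (the proxy address and endpoint below are illustrative):

import httpx

proxy_url = "http://ssrf-proxy:3128"  # illustrative; the diff reads these values from dify_config
mounts = {
    "http://": httpx.HTTPTransport(proxy=proxy_url),
    "https://": httpx.HTTPTransport(proxy=proxy_url),
}
with httpx.Client(mounts=mounts, timeout=httpx.Timeout(60.0, connect=5.0)) as client:
    response = client.post("https://example.com/hook", json={"point": "ping", "params": {}})
    response.raise_for_status()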
|
|
|
|||
|
|
@ -343,7 +343,7 @@ class IndexingRunner:
|
|||
|
||||
if file_detail:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=dataset_document.doc_form,
|
||||
)
|
||||
|
|
@ -356,7 +356,7 @@ class IndexingRunner:
|
|||
):
|
||||
raise ValueError("no notion import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
|
|
@ -379,7 +379,7 @@ class IndexingRunner:
|
|||
):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
|
|
|
|||
|
|
@ -224,8 +224,8 @@ def _handle_native_json_schema(
|
|||
|
||||
# Set appropriate response format if required by the model
|
||||
for rule in rules:
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
|
||||
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA
|
||||
|
||||
return model_parameters
|
||||
|
||||
|
|
@ -239,10 +239,10 @@ def _set_response_format(model_parameters: dict, rules: list):
|
|||
"""
|
||||
for rule in rules:
|
||||
if rule.name == "response_format":
|
||||
if ResponseFormat.JSON.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON.value
|
||||
elif ResponseFormat.JSON_OBJECT.value in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
|
||||
if ResponseFormat.JSON in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON
|
||||
elif ResponseFormat.JSON_OBJECT in rule.options:
|
||||
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT
|
||||
|
||||
|
||||
def _handle_prompt_based_schema(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -213,9 +213,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
span_kind = OpenInferenceSpanKindValues.LLM
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
|
|
@ -230,18 +230,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
|||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
|
||||
if trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE.value
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
|
|
@ -88,7 +88,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
self.add_trace(langfuse_trace_data=trace_data)
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=trace_id,
|
||||
|
|
@ -103,7 +103,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
metadata=metadata,
|
||||
|
|
@ -253,7 +253,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
trace_data = LangfuseTrace(
|
||||
id=trace_id,
|
||||
user_id=user_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
input={
|
||||
"message": trace_info.inputs,
|
||||
"files": file_list,
|
||||
|
|
@ -303,7 +303,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
span_data = LangfuseSpan(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -331,7 +331,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
)
|
||||
|
||||
generation_data = LangfuseGeneration(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=str(trace_info.suggested_question),
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
|
|
@ -349,7 +349,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output={"documents": trace_info.documents},
|
||||
trace_id=trace_info.trace_id or trace_info.message_id,
|
||||
|
|
@ -377,7 +377,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_generation_trace_data = LangfuseTrace(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
user_id=trace_info.tenant_id,
|
||||
|
|
@ -388,7 +388,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
|||
self.add_trace(langfuse_trace_data=name_generation_trace_data)
|
||||
|
||||
name_generation_span_data = LangfuseSpan(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
input=trace_info.inputs,
|
||||
output=trace_info.outputs,
|
||||
trace_id=trace_info.conversation_id,
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_id:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
run_type=LangSmithRunType.chain,
|
||||
|
|
@ -110,7 +110,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
run_type=LangSmithRunType.tool,
|
||||
start_time=trace_info.workflow_data.created_at,
|
||||
|
|
@ -271,7 +271,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
run_type=LangSmithRunType.chain,
|
||||
start_time=trace_info.start_time,
|
||||
|
|
@ -327,7 +327,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
langsmith_run = LangSmithRunModel(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -362,7 +362,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if message_data is None:
|
||||
return
|
||||
suggested_question_run = LangSmithRunModel(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
run_type=LangSmithRunType.tool,
|
||||
|
|
@ -391,7 +391,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
if trace_info.message_data is None:
|
||||
return
|
||||
dataset_retrieval_run = LangSmithRunModel(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
run_type=LangSmithRunType.retriever,
|
||||
|
|
@ -447,7 +447,7 @@ class LangSmithDataTrace(BaseTraceInstance):
|
|||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
name_run = LangSmithRunModel(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
run_type=LangSmithRunType.tool,
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
|
|
@ -125,7 +125,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
"id": root_span_id,
|
||||
"parent_span_id": None,
|
||||
"trace_id": opik_trace_id,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE.value,
|
||||
"name": TraceTaskName.WORKFLOW_TRACE,
|
||||
"input": wrap_dict("input", trace_info.workflow_run_inputs),
|
||||
"output": wrap_dict("output", trace_info.workflow_run_outputs),
|
||||
"start_time": trace_info.start_time,
|
||||
|
|
@ -138,7 +138,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
else:
|
||||
trace_data = {
|
||||
"id": opik_trace_id,
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": workflow_metadata,
|
||||
|
|
@ -290,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, dify_trace_id),
|
||||
"name": TraceTaskName.MESSAGE_TRACE.value,
|
||||
"name": TraceTaskName.MESSAGE_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(metadata),
|
||||
|
|
@ -329,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.MODERATION_TRACE.value,
|
||||
"name": TraceTaskName.MODERATION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
|
|
@ -355,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
"name": TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or message_data.updated_at,
|
||||
|
|
@ -375,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
"name": TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
"type": "tool",
|
||||
"start_time": start_time,
|
||||
"end_time": trace_info.end_time or trace_info.message_data.updated_at,
|
||||
|
|
@ -405,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
|
|
@ -420,7 +420,7 @@ class OpikDataTrace(BaseTraceInstance):
|
|||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_info.message_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
total_tokens=trace_info.total_tokens,
|
||||
|
|
@ -126,7 +126,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
file_list=trace_info.file_list,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
id=trace_info.workflow_run_id,
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE.value),
|
||||
op=str(TraceTaskName.WORKFLOW_TRACE),
|
||||
inputs=dict(trace_info.workflow_run_inputs),
|
||||
outputs=dict(trace_info.workflow_run_outputs),
|
||||
attributes=workflow_attributes,
|
||||
|
|
@ -253,7 +253,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
message_run = WeaveTraceModel(
|
||||
id=trace_id,
|
||||
op=str(TraceTaskName.MESSAGE_TRACE.value),
|
||||
op=str(TraceTaskName.MESSAGE_TRACE),
|
||||
input_tokens=trace_info.message_tokens,
|
||||
output_tokens=trace_info.answer_tokens,
|
||||
total_tokens=trace_info.total_tokens,
|
||||
|
|
@ -300,7 +300,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
moderation_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.MODERATION_TRACE.value),
|
||||
op=str(TraceTaskName.MODERATION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
|
|
@ -330,7 +330,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
suggested_question_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
|
||||
op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.suggested_question,
|
||||
attributes=attributes,
|
||||
|
|
@ -355,7 +355,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
dataset_retrieval_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
|
||||
op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs={"documents": trace_info.documents},
|
||||
attributes=attributes,
|
||||
|
|
@ -397,7 +397,7 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
name_run = WeaveTraceModel(
|
||||
id=str(uuid.uuid4()),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
|
||||
op=str(TraceTaskName.GENERATE_NAME_TRACE),
|
||||
inputs=trace_info.inputs,
|
||||
outputs=trace_info.outputs,
|
||||
attributes=attributes,
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
|
|||
instruction=instruction, # instruct with variables are not supported
|
||||
)
|
||||
node_data_dict = node_data.model_dump()
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
|
||||
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
|
||||
execution = workflow_service.run_free_workflow_node(
|
||||
node_data_dict,
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -85,13 +85,13 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
raise ValueError("prompt_messages must be a list")
|
||||
|
||||
for i in range(len(v)):
|
||||
if v[i]["role"] == PromptMessageRole.USER.value:
|
||||
if v[i]["role"] == PromptMessageRole.USER:
|
||||
v[i] = UserPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
|
||||
elif v[i]["role"] == PromptMessageRole.ASSISTANT:
|
||||
v[i] = AssistantPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
|
||||
elif v[i]["role"] == PromptMessageRole.SYSTEM:
|
||||
v[i] = SystemPromptMessage.model_validate(v[i])
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL.value:
|
||||
elif v[i]["role"] == PromptMessageRole.TOOL:
|
||||
v[i] = ToolPromptMessage.model_validate(v[i])
|
||||
else:
|
||||
v[i] = PromptMessage.model_validate(v[i])
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -47,29 +46,56 @@ class BasePluginClient:
|
|||
data: bytes | dict | str | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
stream: bool = False,
|
||||
) -> requests.Response:
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
request_kwargs: dict[str, Any] = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"params": params,
|
||||
"files": files,
|
||||
}
|
||||
if isinstance(prepared_data, dict):
|
||||
request_kwargs["data"] = prepared_data
|
||||
elif prepared_data is not None:
|
||||
request_kwargs["content"] = prepared_data
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
|
||||
)
|
||||
except requests.ConnectionError:
|
||||
response = httpx.request(**request_kwargs)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
return response
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
path: str,
|
||||
headers: dict | None,
|
||||
data: bytes | dict | str | None,
|
||||
params: dict | None,
|
||||
files: dict | None,
|
||||
) -> tuple[str, dict, bytes | dict | str | None, dict | None, dict | None]:
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
|
||||
|
||||
prepared_data: bytes | dict | str | None = (
|
||||
data if isinstance(data, (bytes, str, dict)) or data is None else None
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
if prepared_headers.get("Content-Type") == "application/json":
|
||||
prepared_data = json.dumps(data)
|
||||
else:
|
||||
prepared_data = data
|
||||
|
||||
return str(url), prepared_headers, prepared_data, params, files
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
|
|
@ -78,17 +104,38 @@ class BasePluginClient:
headers: dict | None = None,
data: bytes | dict | None = None,
files: dict | None = None,
) -> Generator[bytes, None, None]:
) -> Generator[str, None, None]:
"""
Make a stream request to the plugin daemon inner API
"""
response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
url, headers, prepared_data, params, files = self._prepare_request(path, headers, data, params, files)

stream_kwargs: dict[str, Any] = {
"method": method,
"url": url,
"headers": headers,
"params": params,
"files": files,
}
if isinstance(prepared_data, dict):
stream_kwargs["data"] = prepared_data
elif prepared_data is not None:
stream_kwargs["content"] = prepared_data

try:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if raw_line is None:
continue
line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
line = line.strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
except httpx.RequestError:
logger.exception("Stream request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")

def _stream_request_with_model(
self,
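A minimal sketch of how a caller might consume what this generator yields, assuming each data: payload is a JSON object (the path below is illustrative):

import json

def iter_plugin_events(client, path: str):
    # _stream_request yields the text after the "data:" prefix, one event at a time.
    for payload in client._stream_request("POST", path, data={}):
        yield json.loads(payload)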
@ -139,7 +186,7 @@ class BasePluginClient:
|
|||
try:
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
response.raise_for_status()
|
||||
except HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
|
@ -208,7 +255,7 @@ class BasePluginClient:
|
|||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
logger.error("Error in stream response for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
|||
|
|
@ -610,7 +610,7 @@ class ProviderManager:
|
|||
|
||||
provider_quota_to_provider_record_dict = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
|
|
@ -627,8 +627,8 @@ class ProviderManager:
|
|||
tenant_id=tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=ProviderQuotaType.TRIAL.value,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=quota.quota_limit, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
|
|
@ -641,8 +641,8 @@ class ProviderManager:
|
|||
stmt = select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == ModelProviderID(provider_name).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.provider_type == ProviderType.SYSTEM,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL,
|
||||
)
|
||||
existed_provider_record = db.session.scalar(stmt)
|
||||
if not existed_provider_record:
|
||||
|
|
@ -702,7 +702,7 @@ class ProviderManager:
|
|||
"""Get custom provider configuration."""
|
||||
# Find custom provider record (non-system)
|
||||
custom_provider_record = next(
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM), None
|
||||
)
|
||||
|
||||
if not custom_provider_record:
|
||||
|
|
@ -905,7 +905,7 @@ class ProviderManager:
|
|||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
|
|
@ -1046,7 +1046,7 @@ class ProviderManager:
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class DataPostProcessor:
|
|||
reranking_model: dict | None = None,
|
||||
weights: dict | None = None,
|
||||
) -> BaseRerankRunner | None:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
|
||||
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
|
||||
runner = RerankRunnerFactory.create_rerank_runner(
|
||||
runner_type=reranking_mode,
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -62,7 +62,7 @@ class DataPostProcessor:
|
|||
),
|
||||
)
|
||||
return runner
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL.value:
|
||||
elif reranking_mode == RerankMode.RERANKING_MODEL:
|
||||
rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model)
|
||||
if rerank_model_instance is None:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from models.dataset import Document as DatasetDocument
|
|||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
|
|
@ -34,7 +34,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
|
|
@ -56,7 +56,7 @@ class RetrievalService:
|
|||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
|
|
@ -107,7 +107,7 @@ class RetrievalService:
|
|||
raise ValueError(";\n".join(exceptions))
|
||||
|
||||
# Deduplicate documents for hybrid search to avoid duplicate chunks
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
|
||||
all_documents = cls._deduplicate_documents(all_documents)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
|
|
@ -220,7 +220,7 @@ class RetrievalService:
|
|||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
|
|
@ -245,10 +245,10 @@ class RetrievalService:
|
|||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
|
|
@ -293,10 +293,10 @@ class RetrievalService:
|
|||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
|
||||
str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False
|
||||
)
|
||||
all_documents.extend(
|
||||
data_post_processor.invoke(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,388 @@
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import mysql.connector
|
||||
from mysql.connector import Error as MySQLError
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlibabaCloudMySQLVectorConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
database: str
|
||||
max_connection: int
|
||||
charset: str = "utf8mb4"
|
||||
distance_function: Literal["cosine", "euclidean"] = "cosine"
|
||||
hnsw_m: int = 6
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values.get("host"):
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_PORT is required")
|
||||
if not values.get("user"):
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_USER is required")
|
||||
if values.get("password") is None:
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_PASSWORD is required")
|
||||
if not values.get("database"):
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_DATABASE is required")
|
||||
if not values.get("max_connection"):
|
||||
raise ValueError("config ALIBABACLOUD_MYSQL_MAX_CONNECTION is required")
|
||||
return values
|
||||
|
||||
|
||||
SQL_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
text LONGTEXT NOT NULL,
|
||||
meta JSON NOT NULL,
|
||||
embedding VECTOR({dimension}) NOT NULL,
|
||||
VECTOR INDEX (embedding) M={hnsw_m} DISTANCE={distance_function}
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
|
||||
"""
|
||||
|
||||
SQL_CREATE_META_INDEX = """
|
||||
CREATE INDEX idx_{index_hash}_meta ON {table_name}
|
||||
((CAST(JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) AS CHAR(36))));
|
||||
"""
|
||||
|
||||
SQL_CREATE_FULLTEXT_INDEX = """
|
||||
CREATE FULLTEXT INDEX idx_{index_hash}_text ON {table_name} (text) WITH PARSER ngram;
|
||||
"""
|
||||
|
||||
|
||||
class AlibabaCloudMySQLVector(BaseVector):
|
||||
def __init__(self, collection_name: str, config: AlibabaCloudMySQLVectorConfig):
|
||||
super().__init__(collection_name)
|
||||
self.pool = self._create_connection_pool(config)
|
||||
self.table_name = collection_name.lower()
|
||||
self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
|
||||
self.distance_function = config.distance_function.lower()
|
||||
self.hnsw_m = config.hnsw_m
|
||||
self._check_vector_support()
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ALIBABACLOUD_MYSQL
|
||||
|
||||
def _create_connection_pool(self, config: AlibabaCloudMySQLVectorConfig):
|
||||
# Create connection pool using mysql-connector-python pooling
|
||||
pool_config: dict[str, Any] = {
|
||||
"host": config.host,
|
||||
"port": config.port,
|
||||
"user": config.user,
|
||||
"password": config.password,
|
||||
"database": config.database,
|
||||
"charset": config.charset,
|
||||
"autocommit": True,
|
||||
"pool_name": f"pool_{self.collection_name}",
|
||||
"pool_size": config.max_connection,
|
||||
"pool_reset_session": True,
|
||||
}
|
||||
return mysql.connector.pooling.MySQLConnectionPool(**pool_config)
|
||||
|
||||
def _check_vector_support(self):
|
||||
"""Check if the MySQL server supports vector operations."""
|
||||
try:
|
||||
with self._get_cursor() as cur:
|
||||
# Check MySQL version and vector support
|
||||
cur.execute("SELECT VERSION()")
|
||||
version = cur.fetchone()["VERSION()"]
|
||||
logger.debug("Connected to MySQL version: %s", version)
|
||||
# Try to execute a simple vector function to verify support
|
||||
cur.execute("SELECT VEC_FromText('[1,2,3]') IS NOT NULL as vector_support")
|
||||
result = cur.fetchone()
|
||||
if not result or not result.get("vector_support"):
|
||||
raise ValueError(
|
||||
"RDS MySQL Vector functions are not available."
|
||||
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if "FUNCTION" in str(e) and "VEC_FromText" in str(e):
|
||||
raise ValueError(
|
||||
"RDS MySQL Vector functions are not available."
|
||||
" Please ensure you're using RDS MySQL 8.0.36+ with Vector support."
|
||||
) from e
|
||||
raise e
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
conn = self.pool.get_connection()
|
||||
cur = conn.cursor(dictionary=True)
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
return self.add_texts(texts, embeddings)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
pks = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
pks.append(doc_id)
|
||||
# Convert embedding list to Aliyun MySQL vector format
|
||||
vector_str = "[" + ",".join(map(str, embeddings[i])) + "]"
|
||||
values.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
vector_str,
|
||||
)
|
||||
)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
insert_sql = (
|
||||
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (%s, %s, %s, VEC_FromText(%s))"
|
||||
)
|
||||
cur.executemany(insert_sql, values)
|
||||
return pks
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
if not ids:
|
||||
return []
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
placeholders = ",".join(["%s"] * len(ids))
|
||||
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
docs = []
|
||||
for record in cur:
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
# Avoiding crashes caused by performing delete operations on empty lists
|
||||
if not ids:
|
||||
return
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
placeholders = ",".join(["%s"] * len(ids))
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
|
||||
except MySQLError as e:
|
||||
if e.errno == 1146: # Table doesn't exist
|
||||
logger.warning("Table %s not found, skipping delete operation.", self.table_name)
|
||||
return
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"DELETE FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s", (f"$.{key}", value)
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Search the nearest neighbors to a vector using RDS MySQL vector distance functions.
|
||||
|
||||
:param query_vector: The input vector to search for similar items.
|
||||
:return: List of Documents that are nearest to the query vector.
|
||||
"""
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params = []
|
||||
|
||||
if document_ids_filter:
|
||||
placeholders = ",".join(["%s"] * len(document_ids_filter))
|
||||
where_clause = f" WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
# Convert query vector to RDS MySQL vector format
|
||||
query_vector_str = "[" + ",".join(map(str, query_vector)) + "]"
|
||||
|
||||
# Use RSD MySQL's native vector distance functions
|
||||
with self._get_cursor() as cur:
|
||||
# Choose distance function based on configuration
|
||||
distance_func = "VEC_DISTANCE_COSINE" if self.distance_function == "cosine" else "VEC_DISTANCE_EUCLIDEAN"
|
||||
|
||||
# Note: RDS MySQL optimizer will use vector index when ORDER BY + LIMIT are present
|
||||
# Use column alias in ORDER BY to avoid calculating distance twice
|
||||
sql = f"""
|
||||
SELECT meta, text,
|
||||
{distance_func}(embedding, VEC_FromText(%s)) AS distance
|
||||
FROM {self.table_name}
|
||||
{where_clause}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
"""
|
||||
query_params = [query_vector_str] + params + [top_k]
|
||||
|
||||
cur.execute(sql, query_params)
|
||||
|
||||
docs = []
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
|
||||
for record in cur:
|
||||
try:
|
||||
distance = float(record["distance"])
|
||||
# Convert distance to similarity score
|
||||
if self.distance_function == "cosine":
|
||||
# For cosine distance: similarity = 1 - distance
|
||||
similarity = 1.0 - distance
|
||||
else:
|
||||
# For euclidean distance: use inverse relationship
|
||||
# similarity = 1 / (1 + distance)
|
||||
similarity = 1.0 / (1.0 + distance)
|
||||
|
||||
metadata = record["meta"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
metadata["score"] = similarity
|
||||
metadata["distance"] = distance
|
||||
|
||||
if similarity >= score_threshold:
|
||||
docs.append(Document(page_content=record["text"], metadata=metadata))
|
||||
except (ValueError, json.JSONDecodeError) as e:
|
||||
logger.warning("Error processing search result: %s", e)
|
||||
continue
|
||||
|
||||
return docs
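A quick numeric check on the distance-to-similarity conversion above (illustrative values only):

# cosine:    distance 0.12 -> similarity 1 - 0.12      = 0.88
# euclidean: distance 3.0  -> similarity 1 / (1 + 3.0) = 0.25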
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
params = []
|
||||
|
||||
if document_ids_filter:
|
||||
placeholders = ",".join(["%s"] * len(document_ids_filter))
|
||||
where_clause = f" AND JSON_UNQUOTE(JSON_EXTRACT(meta, '$.document_id')) IN ({placeholders}) "
|
||||
params.extend(document_ids_filter)
|
||||
|
||||
with self._get_cursor() as cur:
|
||||
# Build query parameters: query (twice for MATCH clauses), document_ids_filter (if any), top_k
|
||||
            query_params = [query, query] + params + [top_k]
            cur.execute(
                f"""SELECT meta, text,
                MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE) AS score
                FROM {self.table_name}
                WHERE MATCH(text) AGAINST(%s IN NATURAL LANGUAGE MODE)
                {where_clause}
                ORDER BY score DESC
                LIMIT %s""",
                query_params,
            )
            docs = []
            for record in cur:
                metadata = record["meta"]
                if isinstance(metadata, str):
                    metadata = json.loads(metadata)
                metadata["score"] = float(record["score"])
                docs.append(Document(page_content=record["text"], metadata=metadata))
        return docs

    def delete(self):
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

    def _create_collection(self, dimension: int):
        collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{collection_exist_cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            if redis_client.get(collection_exist_cache_key):
                return

            with self._get_cursor() as cur:
                # Create table with vector column and vector index
                cur.execute(
                    SQL_CREATE_TABLE.format(
                        table_name=self.table_name,
                        dimension=dimension,
                        distance_function=self.distance_function,
                        hnsw_m=self.hnsw_m,
                    )
                )
                # Create metadata index (check if exists first)
                try:
                    cur.execute(SQL_CREATE_META_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
                except MySQLError as e:
                    if e.errno != 1061:  # Duplicate key name
                        logger.warning("Could not create meta index: %s", e)

                # Create full-text index for text search
                try:
                    cur.execute(
                        SQL_CREATE_FULLTEXT_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)
                    )
                except MySQLError as e:
                    if e.errno != 1061:  # Duplicate key name
                        logger.warning("Could not create fulltext index: %s", e)

            redis_client.set(collection_exist_cache_key, 1, ex=3600)


class AlibabaCloudMySQLVectorFactory(AbstractVectorFactory):
    def _validate_distance_function(self, distance_function: str) -> Literal["cosine", "euclidean"]:
        """Validate and return the distance function as a proper Literal type."""
        if distance_function not in ["cosine", "euclidean"]:
            raise ValueError(f"Invalid distance function: {distance_function}. Must be 'cosine' or 'euclidean'")
        return cast(Literal["cosine", "euclidean"], distance_function)

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AlibabaCloudMySQLVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.ALIBABACLOUD_MYSQL, collection_name)
            )
        return AlibabaCloudMySQLVector(
            collection_name=collection_name,
            config=AlibabaCloudMySQLVectorConfig(
                host=dify_config.ALIBABACLOUD_MYSQL_HOST or "localhost",
                port=dify_config.ALIBABACLOUD_MYSQL_PORT,
                user=dify_config.ALIBABACLOUD_MYSQL_USER or "root",
                password=dify_config.ALIBABACLOUD_MYSQL_PASSWORD or "",
                database=dify_config.ALIBABACLOUD_MYSQL_DATABASE or "dify",
                max_connection=dify_config.ALIBABACLOUD_MYSQL_MAX_CONNECTION,
                charset=dify_config.ALIBABACLOUD_MYSQL_CHARSET or "utf8mb4",
                distance_function=self._validate_distance_function(
                    dify_config.ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION or "cosine"
                ),
                hnsw_m=dify_config.ALIBABACLOUD_MYSQL_HNSW_M or 6,
            ),
        )
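For orientation, a minimal usage sketch of the new factory, not part of this commit: it assumes the fragment above belongs to a search_by_full_text(query, **kwargs) method, as in the other vector stores touched by this change, and that dify_config carries the ALIBABACLOUD_MYSQL_* settings introduced here; the dataset fixture, embedding model, and query text are placeholders.

# Hedged sketch only; some_dataset and embedding_model are placeholders, not code from this commit.
factory = AlibabaCloudMySQLVectorFactory()
vector = factory.init_vector(dataset=some_dataset, attributes=[], embeddings=embedding_model)
docs = vector.search_by_full_text("hnsw index tuning", top_k=4)  # top_k follows the BaseVector kwargs convention
for doc in docs:
    # Each hit carries the MATCH ... AGAINST relevance score copied into metadata above.
    print(doc.metadata.get("score"), doc.page_content[:80])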
@@ -488,9 +488,9 @@ class ClickzettaVector(BaseVector):
|
|||
create_table_sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
|
||||
id STRING NOT NULL COMMENT 'Unique document identifier',
|
||||
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
|
||||
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
|
||||
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
|
||||
{Field.CONTENT_KEY} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
|
||||
{Field.METADATA_KEY} JSON COMMENT 'Document metadata including source, type, and other attributes',
|
||||
{Field.VECTOR} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
|
||||
'High-dimensional embedding vector for semantic similarity search',
|
||||
PRIMARY KEY (id)
|
||||
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
|
||||
|
|
@@ -519,15 +519,15 @@ class ClickzettaVector(BaseVector):
|
|||
existing_indexes = cursor.fetchall()
|
||||
for idx in existing_indexes:
|
||||
# Check if vector index already exists on the embedding column
|
||||
if Field.VECTOR.value in str(idx).lower():
|
||||
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
|
||||
if Field.VECTOR in str(idx).lower():
|
||||
logger.info("Vector index already exists on column %s", Field.VECTOR)
|
||||
return
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logger.warning("Failed to check existing indexes: %s", e)
|
||||
|
||||
index_sql = f"""
|
||||
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
|
||||
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
|
||||
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR})
|
||||
PROPERTIES (
|
||||
"distance.function" = "{self._config.vector_distance_function}",
|
||||
"scalar.type" = "f32",
|
||||
|
|
@@ -560,17 +560,17 @@ class ClickzettaVector(BaseVector):
|
|||
# More precise check: look for inverted index specifically on the content column
|
||||
if (
|
||||
"inverted" in idx_str
|
||||
and Field.CONTENT_KEY.value.lower() in idx_str
|
||||
and Field.CONTENT_KEY.lower() in idx_str
|
||||
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
|
||||
):
|
||||
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
|
||||
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY, idx)
|
||||
return
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logger.warning("Failed to check existing indexes: %s", e)
|
||||
|
||||
index_sql = f"""
|
||||
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
|
||||
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
|
||||
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY})
|
||||
PROPERTIES (
|
||||
"analyzer" = "{self._config.analyzer_type}",
|
||||
"mode" = "{self._config.analyzer_mode}"
|
||||
|
|
@@ -588,13 +588,13 @@ class ClickzettaVector(BaseVector):
|
|||
or "with the same type" in error_msg
|
||||
or "cannot create inverted index" in error_msg
|
||||
) and "already has index" in error_msg:
|
||||
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
|
||||
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY)
|
||||
# Try to get the existing index name for logging
|
||||
try:
|
||||
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
|
||||
existing_indexes = cursor.fetchall()
|
||||
for idx in existing_indexes:
|
||||
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
|
||||
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.lower() in str(idx).lower():
|
||||
logger.info("Found existing inverted index: %s", idx)
|
||||
break
|
||||
except (RuntimeError, ValueError):
|
||||
|
|
@@ -669,7 +669,7 @@ class ClickzettaVector(BaseVector):
|
|||
|
||||
# Use parameterized INSERT with executemany for better performance and security
|
||||
# Cast JSON and VECTOR in SQL, pass raw data as parameters
|
||||
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
|
||||
columns = f"id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}"
|
||||
insert_sql = (
|
||||
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
|
||||
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
|
||||
|
|
@@ -767,7 +767,7 @@ class ClickzettaVector(BaseVector):
|
|||
# Use json_extract_string function for ClickZetta compatibility
|
||||
sql = (
|
||||
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
|
||||
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
|
||||
f"WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?"
|
||||
)
|
||||
cursor.execute(sql, binding_params=[value])
|
||||
|
||||
|
|
@ -795,9 +795,7 @@ class ClickzettaVector(BaseVector):
|
|||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
|
|
@ -808,23 +806,21 @@ class ClickzettaVector(BaseVector):
|
|||
distance_func = "COSINE_DISTANCE"
|
||||
if score_threshold > 0:
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
filter_clauses.append(
|
||||
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
|
||||
)
|
||||
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {2 - score_threshold}")
|
||||
else:
|
||||
# For L2 distance, smaller is better
|
||||
distance_func = "L2_DISTANCE"
|
||||
if score_threshold > 0:
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
|
||||
filter_clauses.append(f"{distance_func}({Field.VECTOR}, {query_vector_str}) < {score_threshold}")
|
||||
|
||||
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
|
||||
|
||||
# Execute vector search query
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
|
||||
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
|
||||
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY},
|
||||
{distance_func}({Field.VECTOR}, {query_vector_str}) AS distance
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
ORDER BY distance
|
||||
|
|
@ -887,9 +883,7 @@ class ClickzettaVector(BaseVector):
|
|||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
|
|
@ -897,13 +891,13 @@ class ClickzettaVector(BaseVector):
|
|||
# match_all requires all terms to be present
|
||||
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
|
||||
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY}, '{escaped_query}')")
|
||||
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
# Execute full-text search query
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
|
|
@ -986,19 +980,17 @@ class ClickzettaVector(BaseVector):
|
|||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
filter_clauses.append(f"json_extract_string({Field.METADATA_KEY}, '$.document_id') IN ({doc_ids_str})")
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
|
|
|
|||
|
|
@@ -57,18 +57,18 @@ class ElasticSearchJaVector(ElasticSearchVector):
|
|||
}
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {
|
||||
Field.CONTENT_KEY: {
|
||||
"type": "text",
|
||||
"analyzer": "ja_analyzer",
|
||||
"search_analyzer": "ja_analyzer",
|
||||
},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
Field.VECTOR: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
Field.METADATA_KEY: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ import math
|
|||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from elasticsearch import ConnectionError as ElasticsearchConnectionError
|
||||
from elasticsearch import Elasticsearch
|
||||
from flask import current_app
|
||||
from packaging.version import parse as parse_version
|
||||
|
|
@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
|
|||
if not client.ping():
|
||||
raise ConnectionError("Failed to connect to Elasticsearch")
|
||||
|
||||
except requests.ConnectionError as e:
|
||||
except ElasticsearchConnectionError as e:
|
||||
raise ConnectionError(f"Vector database connection error: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
|
||||
|
|
@ -163,9 +163,9 @@ class ElasticSearchVector(BaseVector):
|
|||
index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] or None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata or {},
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i] or None,
|
||||
Field.METADATA_KEY: documents[i].metadata or {},
|
||||
},
|
||||
)
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
|
@ -193,7 +193,7 @@ class ElasticSearchVector(BaseVector):
|
|||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
num_candidates = math.ceil(top_k * 1.5)
|
||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||
knn = {"field": Field.VECTOR, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
|
|
@ -205,9 +205,9 @@ class ElasticSearchVector(BaseVector):
|
|||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"][Field.VECTOR],
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
|
|
@ -224,13 +224,13 @@ class ElasticSearchVector(BaseVector):
|
|||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY: query}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
if document_ids_filter:
|
||||
query_str = {
|
||||
"bool": {
|
||||
"must": {"match": {Field.CONTENT_KEY.value: query}},
|
||||
"must": {"match": {Field.CONTENT_KEY: query}},
|
||||
"filter": {"terms": {"metadata.document_id": document_ids_filter}},
|
||||
}
|
||||
}
|
||||
|
|
@ -240,9 +240,9 @@ class ElasticSearchVector(BaseVector):
|
|||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"][Field.VECTOR],
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -270,14 +270,14 @@ class ElasticSearchVector(BaseVector):
|
|||
dim = len(embeddings[0])
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
Field.CONTENT_KEY: {"type": "text"},
|
||||
Field.VECTOR: { # Make sure the dimension is correct here
|
||||
"type": "dense_vector",
|
||||
"dims": dim,
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
Field.METADATA_KEY: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
|
||||
|
|
|
|||
|
|
@@ -67,9 +67,9 @@ class HuaweiCloudVector(BaseVector):
|
|||
index=self._collection_name,
|
||||
id=uuids[i],
|
||||
document={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i] or None,
|
||||
Field.METADATA_KEY.value: documents[i].metadata or {},
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i] or None,
|
||||
Field.METADATA_KEY: documents[i].metadata or {},
|
||||
},
|
||||
)
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
|
@ -101,7 +101,7 @@ class HuaweiCloudVector(BaseVector):
|
|||
"size": top_k,
|
||||
"query": {
|
||||
"vector": {
|
||||
Field.VECTOR.value: {
|
||||
Field.VECTOR: {
|
||||
"vector": query_vector,
|
||||
"topk": top_k,
|
||||
}
|
||||
|
|
@ -116,9 +116,9 @@ class HuaweiCloudVector(BaseVector):
|
|||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"][Field.VECTOR],
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
|
|
@ -135,15 +135,15 @@ class HuaweiCloudVector(BaseVector):
|
|||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||
query_str = {"match": {Field.CONTENT_KEY: query}}
|
||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||
docs = []
|
||||
for hit in results["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"][Field.VECTOR],
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -171,8 +171,8 @@ class HuaweiCloudVector(BaseVector):
|
|||
dim = len(embeddings[0])
|
||||
mappings = {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
Field.VECTOR.value: { # Make sure the dimension is correct here
|
||||
Field.CONTENT_KEY: {"type": "text"},
|
||||
Field.VECTOR: { # Make sure the dimension is correct here
|
||||
"type": "vector",
|
||||
"dimension": dim,
|
||||
"indexing": True,
|
||||
|
|
@ -181,7 +181,7 @@ class HuaweiCloudVector(BaseVector):
|
|||
"neighbors": 32,
|
||||
"efc": 128,
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
Field.METADATA_KEY: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
|
|
|
|||
|
|
@@ -125,9 +125,9 @@ class LindormVectorStore(BaseVector):
|
|||
}
|
||||
}
|
||||
action_values: dict[str, Any] = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i],
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
|
|
@ -149,7 +149,7 @@ class LindormVectorStore(BaseVector):
|
|||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query: dict[str, Any] = {
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
|
||||
}
|
||||
if self._using_ugc:
|
||||
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
|
||||
|
|
@ -252,14 +252,14 @@ class LindormVectorStore(BaseVector):
|
|||
search_query: dict[str, Any] = {
|
||||
"size": top_k,
|
||||
"_source": True,
|
||||
"query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
|
||||
"query": {"knn": {Field.VECTOR: {"vector": query_vector, "k": top_k}}},
|
||||
}
|
||||
|
||||
final_ext: dict[str, Any] = {"lvector": {}}
|
||||
if filters is not None and len(filters) > 0:
|
||||
# when using filter, transform filter from List[Dict] to Dict as valid format
|
||||
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
|
||||
search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict
|
||||
search_query["query"]["knn"][Field.VECTOR]["filter"] = filter_dict # filter should be Dict
|
||||
final_ext["lvector"]["filter_type"] = "pre_filter"
|
||||
|
||||
if final_ext != {"lvector": {}}:
|
||||
|
|
@ -279,9 +279,9 @@ class LindormVectorStore(BaseVector):
|
|||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
page_content=hit["_source"][Field.CONTENT_KEY],
|
||||
vector=hit["_source"][Field.VECTOR],
|
||||
metadata=hit["_source"][Field.METADATA_KEY],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
|
|
@ -318,9 +318,9 @@ class LindormVectorStore(BaseVector):
|
|||
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value)
|
||||
vector = hit["_source"].get(Field.VECTOR.value)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY)
|
||||
vector = hit["_source"].get(Field.VECTOR)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
|
|
@ -342,8 +342,8 @@ class LindormVectorStore(BaseVector):
|
|||
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
Field.VECTOR.value: {
|
||||
Field.CONTENT_KEY: {"type": "text"},
|
||||
Field.VECTOR: {
|
||||
"type": "knn_vector",
|
||||
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
|
||||
"method": {
|
||||
|
|
|
|||
|
|
@@ -85,7 +85,7 @@ class MilvusVector(BaseVector):
|
|||
collection_info = self._client.describe_collection(self._collection_name)
|
||||
fields = [field["name"] for field in collection_info["fields"]]
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self._fields = [f for f in fields if f != Field.PRIMARY_KEY.value]
|
||||
self._fields = [f for f in fields if f != Field.PRIMARY_KEY]
|
||||
|
||||
def _check_hybrid_search_support(self) -> bool:
|
||||
"""
|
||||
|
|
@ -130,9 +130,9 @@ class MilvusVector(BaseVector):
|
|||
insert_dict = {
|
||||
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
|
||||
# function will automatically convert the native text into a sparse vector for us.
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i],
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
}
|
||||
insert_dict_list.append(insert_dict)
|
||||
# Total insert count
|
||||
|
|
@ -243,15 +243,15 @@ class MilvusVector(BaseVector):
|
|||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
anns_field=Field.VECTOR.value,
|
||||
anns_field=Field.VECTOR,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
|
|
@ -264,7 +264,7 @@ class MilvusVector(BaseVector):
|
|||
"Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)."
|
||||
)
|
||||
return []
|
||||
if not self.field_exists(Field.SPARSE_VECTOR.value):
|
||||
if not self.field_exists(Field.SPARSE_VECTOR):
|
||||
logger.warning(
|
||||
"Full-text search unavailable: collection missing 'sparse_vector' field; "
|
||||
"recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index."
|
||||
|
|
@ -279,15 +279,15 @@ class MilvusVector(BaseVector):
|
|||
results = self._client.search(
|
||||
collection_name=self._collection_name,
|
||||
data=[query],
|
||||
anns_field=Field.SPARSE_VECTOR.value,
|
||||
anns_field=Field.SPARSE_VECTOR,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return self._process_search_results(
|
||||
results,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY],
|
||||
score_threshold=float(kwargs.get("score_threshold") or 0.0),
|
||||
)
|
||||
|
||||
|
|
@ -311,7 +311,7 @@ class MilvusVector(BaseVector):
|
|||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
|
||||
fields.append(FieldSchema(Field.METADATA_KEY, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field, enable_analyzer will be set True to support milvus automatically
|
||||
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
|
||||
|
|
@ -326,15 +326,15 @@ class MilvusVector(BaseVector):
|
|||
):
|
||||
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
|
||||
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
|
||||
fields.append(FieldSchema(Field.CONTENT_KEY, DataType.VARCHAR, **content_field_kwargs))
|
||||
|
||||
# Create the primary key field
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
|
||||
fields.append(FieldSchema(Field.PRIMARY_KEY, DataType.INT64, is_primary=True, auto_id=True))
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
fields.append(FieldSchema(Field.VECTOR, infer_dtype_bydata(embeddings[0]), dim=dim))
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))
|
||||
fields.append(FieldSchema(Field.SPARSE_VECTOR, DataType.SPARSE_FLOAT_VECTOR))
|
||||
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
|
|
@@ -342,8 +342,8 @@ class MilvusVector(BaseVector):
|
|||
if self._hybrid_search_enabled:
|
||||
bm25_function = Function(
|
||||
name="text_bm25_emb",
|
||||
input_field_names=[Field.CONTENT_KEY.value],
|
||||
output_field_names=[Field.SPARSE_VECTOR.value],
|
||||
input_field_names=[Field.CONTENT_KEY],
|
||||
output_field_names=[Field.SPARSE_VECTOR],
|
||||
function_type=FunctionType.BM25,
|
||||
)
|
||||
schema.add_function(bm25_function)
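Most hunks in this change follow one pattern: vdb Field enum members are now passed directly where a plain string key used to be spelled as Field.X.value, which only stays equivalent if Field is (or now behaves as) a StrEnum elsewhere in this change. A minimal self-contained sketch of that equivalence, with an illustrative member value rather than the real one:

from enum import StrEnum

class FieldSketch(StrEnum):  # stand-in for the real Field enum, value is illustrative
    CONTENT_KEY = "page_content"

# A StrEnum member is a str, so formatting, comparison, and dict lookup all see the raw value.
assert FieldSketch.CONTENT_KEY == "page_content"
assert f"match on {FieldSketch.CONTENT_KEY}" == "match on page_content"
assert {"page_content": "hello"}[FieldSketch.CONTENT_KEY] == "hello"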
|
||||
|
|
@@ -352,12 +352,12 @@ class MilvusVector(BaseVector):
|
|||
|
||||
# Create Index params for the collection
|
||||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
|
||||
|
||||
# Create Sparse Vector Index for the collection
|
||||
if self._hybrid_search_enabled:
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
|
||||
field_name=Field.SPARSE_VECTOR, index_type="AUTOINDEX", metric_type="BM25"
|
||||
)
|
||||
|
||||
# Create the collection
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
|
|
@ -8,6 +8,7 @@ from opensearchpy.helpers import BulkIndexError
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.opensearch_config import AuthMethod
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
|
|
@ -25,7 +26,7 @@ class OpenSearchConfig(BaseModel):
|
|||
port: int
|
||||
secure: bool = False # use_ssl
|
||||
verify_certs: bool = True
|
||||
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
||||
auth_method: AuthMethod = AuthMethod.BASIC
|
||||
user: str | None = None
|
||||
password: str | None = None
|
||||
aws_region: str | None = None
|
||||
|
|
@ -98,9 +99,9 @@ class OpenSearchVector(BaseVector):
|
|||
"_op_type": "index",
|
||||
"_index": self._collection_name.lower(),
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
},
|
||||
}
|
||||
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
|
||||
|
|
@ -116,7 +117,7 @@ class OpenSearchVector(BaseVector):
|
|||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
if response["hits"]["hits"]:
|
||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||
|
|
@ -180,17 +181,17 @@ class OpenSearchVector(BaseVector):
|
|||
|
||||
query = {
|
||||
"size": kwargs.get("top_k", 4),
|
||||
"query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
"query": {"knn": {Field.VECTOR: {Field.VECTOR: query_vector, "k": kwargs.get("top_k", 4)}}},
|
||||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {
|
||||
"script_score": {
|
||||
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
|
||||
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID: document_ids_filter}}]}},
|
||||
"script": {
|
||||
"source": "knn_score",
|
||||
"lang": "knn",
|
||||
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
|
||||
"params": {"field": Field.VECTOR, "query_value": query_vector, "space_type": "l2"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -203,7 +204,7 @@ class OpenSearchVector(BaseVector):
|
|||
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value, {})
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY, {})
|
||||
|
||||
# Make sure metadata is a dictionary
|
||||
if metadata is None:
|
||||
|
|
@ -212,7 +213,7 @@ class OpenSearchVector(BaseVector):
|
|||
metadata["score"] = hit["_score"]
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
if hit["_score"] >= score_threshold:
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
|
||||
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
|
@ -227,9 +228,9 @@ class OpenSearchVector(BaseVector):
|
|||
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY.value)
|
||||
vector = hit["_source"].get(Field.VECTOR.value)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
|
||||
metadata = hit["_source"].get(Field.METADATA_KEY)
|
||||
vector = hit["_source"].get(Field.VECTOR)
|
||||
page_content = hit["_source"].get(Field.CONTENT_KEY)
|
||||
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
|
|
@ -250,8 +251,8 @@ class OpenSearchVector(BaseVector):
|
|||
"settings": {"index": {"knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
Field.CONTENT_KEY.value: {"type": "text"},
|
||||
Field.VECTOR.value: {
|
||||
Field.CONTENT_KEY: {"type": "text"},
|
||||
Field.VECTOR: {
|
||||
"type": "knn_vector",
|
||||
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
|
||||
"method": {
|
||||
|
|
@ -261,7 +262,7 @@ class OpenSearchVector(BaseVector):
|
|||
"parameters": {"ef_construction": 64, "m": 8},
|
||||
},
|
||||
},
|
||||
Field.METADATA_KEY.value: {
|
||||
Field.METADATA_KEY: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
|
||||
|
|
@ -293,7 +294,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
|||
port=dify_config.OPENSEARCH_PORT,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS,
|
||||
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
||||
auth_method=dify_config.OPENSEARCH_AUTH_METHOD,
|
||||
user=dify_config.OPENSEARCH_USER,
|
||||
password=dify_config.OPENSEARCH_PASSWORD,
|
||||
aws_region=dify_config.OPENSEARCH_AWS_REGION,
|
||||
|
|
|
|||
|
|
@ -147,15 +147,13 @@ class QdrantVector(BaseVector):
|
|||
|
||||
# create group_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create doc_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
|
||||
# create document_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create full text index
|
||||
text_index_params = TextIndexParams(
|
||||
|
|
@ -165,9 +163,7 @@ class QdrantVector(BaseVector):
|
|||
max_token_len=20,
|
||||
lowercase=True,
|
||||
)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
|
@ -220,10 +216,10 @@ class QdrantVector(BaseVector):
|
|||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
Field.CONTENT_KEY,
|
||||
Field.METADATA_KEY,
|
||||
group_id or "", # Ensure group_id is never None
|
||||
Field.GROUP_KEY.value,
|
||||
Field.GROUP_KEY,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
|
@ -381,12 +377,12 @@ class QdrantVector(BaseVector):
|
|||
for result in results:
|
||||
if result.payload is None:
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
metadata = result.payload.get(Field.METADATA_KEY) or {}
|
||||
# duplicate check score threshold
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
page_content=result.payload.get(Field.CONTENT_KEY, ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
|
|
@ -433,7 +429,7 @@ class QdrantVector(BaseVector):
|
|||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
|
|
|||
|
|
@@ -55,7 +55,7 @@ class TableStoreVector(BaseVector):
|
|||
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||
self._table_name = f"{collection_name}"
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||
self._tags_field = f"{Field.METADATA_KEY}_tags"
|
||||
|
||||
def create_collection(self, embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
|
|
@ -64,7 +64,7 @@ class TableStoreVector(BaseVector):
|
|||
def get_by_ids(self, ids: list[str]) -> list[Document]:
|
||||
docs = []
|
||||
request = BatchGetRowRequest()
|
||||
columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
|
||||
columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
|
||||
rows_to_get = [[("id", _id)] for _id in ids]
|
||||
request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
|
||||
|
||||
|
|
@ -73,11 +73,7 @@ class TableStoreVector(BaseVector):
|
|||
for item in table_result:
|
||||
if item.is_ok and item.row:
|
||||
kv = {k: v for k, v, _ in item.row.attribute_columns}
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
|
||||
)
|
||||
)
|
||||
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
|
||||
return docs
|
||||
|
||||
def get_type(self) -> str:
|
||||
|
|
@ -95,9 +91,9 @@ class TableStoreVector(BaseVector):
|
|||
self._write_row(
|
||||
primary_key=uuids[i],
|
||||
attributes={
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
Field.CONTENT_KEY: documents[i].page_content,
|
||||
Field.VECTOR: embeddings[i],
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
},
|
||||
)
|
||||
return uuids
|
||||
|
|
@ -180,7 +176,7 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
field_schemas = [
|
||||
tablestore.FieldSchema(
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.CONTENT_KEY,
|
||||
tablestore.FieldType.TEXT,
|
||||
analyzer=tablestore.AnalyzerType.MAXWORD,
|
||||
index=True,
|
||||
|
|
@ -188,7 +184,7 @@ class TableStoreVector(BaseVector):
|
|||
store=False,
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.VECTOR.value,
|
||||
Field.VECTOR,
|
||||
tablestore.FieldType.VECTOR,
|
||||
vector_options=tablestore.VectorOptions(
|
||||
data_type=tablestore.VectorDataType.VD_FLOAT_32,
|
||||
|
|
@ -197,7 +193,7 @@ class TableStoreVector(BaseVector):
|
|||
),
|
||||
),
|
||||
tablestore.FieldSchema(
|
||||
Field.METADATA_KEY.value,
|
||||
Field.METADATA_KEY,
|
||||
tablestore.FieldType.KEYWORD,
|
||||
index=True,
|
||||
store=False,
|
||||
|
|
@ -233,15 +229,15 @@ class TableStoreVector(BaseVector):
|
|||
pk = [("id", primary_key)]
|
||||
|
||||
tags = []
|
||||
for key, value in attributes[Field.METADATA_KEY.value].items():
|
||||
for key, value in attributes[Field.METADATA_KEY].items():
|
||||
tags.append(str(key) + "=" + str(value))
|
||||
|
||||
attribute_columns = [
|
||||
(Field.CONTENT_KEY.value, attributes[Field.CONTENT_KEY.value]),
|
||||
(Field.VECTOR.value, json.dumps(attributes[Field.VECTOR.value])),
|
||||
(Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
|
||||
(Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
|
||||
(
|
||||
Field.METADATA_KEY.value,
|
||||
json.dumps(attributes[Field.METADATA_KEY.value]),
|
||||
Field.METADATA_KEY,
|
||||
json.dumps(attributes[Field.METADATA_KEY]),
|
||||
),
|
||||
(self._tags_field, json.dumps(tags)),
|
||||
]
|
||||
|
|
@ -270,7 +266,7 @@ class TableStoreVector(BaseVector):
|
|||
index_name=self._index_name,
|
||||
search_query=query,
|
||||
columns_to_get=tablestore.ColumnsToGet(
|
||||
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
|
||||
column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -288,7 +284,7 @@ class TableStoreVector(BaseVector):
|
|||
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
knn_vector_query = tablestore.KnnVectorQuery(
|
||||
field_name=Field.VECTOR.value,
|
||||
field_name=Field.VECTOR,
|
||||
top_k=top_k,
|
||||
float32_query_vector=query_vector,
|
||||
)
|
||||
|
|
@ -311,8 +307,8 @@ class TableStoreVector(BaseVector):
|
|||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
|
@ -321,7 +317,7 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
|
@ -343,7 +339,7 @@ class TableStoreVector(BaseVector):
|
|||
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
|
||||
|
||||
if document_ids_filter:
|
||||
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
|
||||
|
|
@ -374,10 +370,10 @@ class TableStoreVector(BaseVector):
|
|||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
if score:
|
||||
|
|
@ -385,7 +381,7 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ from collections.abc import Generator, Iterable, Sequence
|
|||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import httpx
|
||||
import qdrant_client
|
||||
import requests
|
||||
from flask import current_app
|
||||
from httpx import DigestAuth
|
||||
from pydantic import BaseModel
|
||||
from qdrant_client.http import models as rest
|
||||
from qdrant_client.http.models import (
|
||||
|
|
@ -19,7 +20,6 @@ from qdrant_client.http.models import (
|
|||
TokenizerType,
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
from requests.auth import HTTPDigestAuth
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -141,15 +141,13 @@ class TidbOnQdrantVector(BaseVector):
|
|||
|
||||
# create group_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
collection_name, Field.GROUP_KEY, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create doc_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.DOC_ID, field_schema=PayloadSchemaType.KEYWORD)
|
||||
# create document_id payload index
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD
|
||||
collection_name, Field.DOCUMENT_ID, field_schema=PayloadSchemaType.KEYWORD
|
||||
)
|
||||
# create full text index
|
||||
text_index_params = TextIndexParams(
|
||||
|
|
@ -159,9 +157,7 @@ class TidbOnQdrantVector(BaseVector):
|
|||
max_token_len=20,
|
||||
lowercase=True,
|
||||
)
|
||||
self._client.create_payload_index(
|
||||
collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
|
||||
)
|
||||
self._client.create_payload_index(collection_name, Field.CONTENT_KEY, field_schema=text_index_params)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
|
|
@ -211,10 +207,10 @@ class TidbOnQdrantVector(BaseVector):
|
|||
self._build_payloads(
|
||||
batch_texts,
|
||||
batch_metadatas,
|
||||
Field.CONTENT_KEY.value,
|
||||
Field.METADATA_KEY.value,
|
||||
Field.CONTENT_KEY,
|
||||
Field.METADATA_KEY,
|
||||
group_id or "",
|
||||
Field.GROUP_KEY.value,
|
||||
Field.GROUP_KEY,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
|
@ -349,13 +345,13 @@ class TidbOnQdrantVector(BaseVector):
|
|||
for result in results:
|
||||
if result.payload is None:
|
||||
continue
|
||||
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
|
||||
metadata = result.payload.get(Field.METADATA_KEY) or {}
|
||||
# duplicate check score threshold
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(
|
||||
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
|
||||
page_content=result.payload.get(Field.CONTENT_KEY, ""),
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
|
|
@ -392,7 +388,7 @@ class TidbOnQdrantVector(BaseVector):
|
|||
documents = []
|
||||
for result in results:
|
||||
if result:
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
|
||||
document = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
|
||||
documents.append(document)
|
||||
|
||||
return documents
|
||||
|
|
@ -504,10 +500,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
}
|
||||
cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}
|
||||
|
||||
response = requests.post(
|
||||
response = httpx.post(
|
||||
f"{tidb_config.api_url}/clusters",
|
||||
json=cluster_data,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
|
@ -527,10 +523,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
|
|||
|
||||
body = {"password": new_password}
|
||||
|
||||
response = requests.put(
|
||||
response = httpx.put(
|
||||
f"{tidb_config.api_url}/clusters/{cluster_id}/password",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
auth=DigestAuth(tidb_config.public_key, tidb_config.private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
|
|
|||
|
|
@@ -2,8 +2,8 @@ import time
|
|||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
import httpx
|
||||
from httpx import DigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -49,7 +49,7 @@ class TidbService:
|
|||
"rootPassword": password,
|
||||
}
|
||||
|
||||
response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))
|
||||
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
|
|
@ -83,7 +83,7 @@ class TidbService:
|
|||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
|
@ -102,7 +102,7 @@ class TidbService:
|
|||
:return: The response from the API.
|
||||
"""
|
||||
|
||||
response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))
|
||||
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
|
|
@ -127,10 +127,10 @@ class TidbService:
|
|||
|
||||
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
|
||||
|
||||
response = requests.patch(
|
||||
response = httpx.patch(
|
||||
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
|
||||
json=body,
|
||||
auth=HTTPDigestAuth(public_key, private_key),
|
||||
auth=DigestAuth(public_key, private_key),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
|
|
@ -161,9 +161,7 @@ class TidbService:
|
|||
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
|
||||
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
|
||||
params = {"clusterIds": cluster_ids, "view": "BASIC"}
|
||||
response = requests.get(
|
||||
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
|
||||
)
|
||||
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
|
|
@ -224,8 +222,8 @@ class TidbService:
|
|||
clusters.append(cluster_data)
|
||||
|
||||
request_body = {"requests": clusters}
|
||||
response = requests.post(
|
||||
f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
|
||||
response = httpx.post(
|
||||
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
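These TidbService hunks swap requests for httpx while keeping digest authentication; a hedged equivalence sketch follows (endpoint and credentials are placeholders, not values from this commit):

import httpx

# httpx.DigestAuth plays the role requests.auth.HTTPDigestAuth played before;
# the call shape (json=..., auth=...) carries over unchanged.
resp = httpx.post(
    "https://example.invalid/clusters",  # placeholder endpoint
    json={"displayName": "placeholder"},
    auth=httpx.DigestAuth("public-key", "private-key"),  # placeholder credentials
)
print(resp.status_code)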
|
||||
|
|
|
|||
|
|
@@ -55,13 +55,13 @@ class TiDBVector(BaseVector):
|
|||
return Table(
|
||||
self._collection_name,
|
||||
self._orm_base.metadata,
|
||||
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
|
||||
Column(Field.PRIMARY_KEY, String(36), primary_key=True, nullable=False),
|
||||
Column(
|
||||
Field.VECTOR.value,
|
||||
Field.VECTOR,
|
||||
VectorType(dim),
|
||||
nullable=False,
|
||||
),
|
||||
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
|
||||
Column(Field.TEXT_KEY, TEXT, nullable=False),
|
||||
Column("meta", JSON, nullable=False),
|
||||
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
|
||||
Column(
|
||||
|
|
|
|||
|
|
@@ -71,6 +71,12 @@ class Vector:
                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

                return MilvusVectorFactory
            case VectorType.ALIBABACLOUD_MYSQL:
                from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
                    AlibabaCloudMySQLVectorFactory,
                )

                return AlibabaCloudMySQLVectorFactory
            case VectorType.MYSCALE:
                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
||||
|
||||
|
|
|
|||
|
|
@@ -2,6 +2,7 @@ from enum import StrEnum


class VectorType(StrEnum):
    ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
    ANALYTICDB = "analyticdb"
    CHROMA = "chroma"
    MILVUS = "milvus"
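A hedged sketch of how the new enum value plugs into the factory dispatch two hunks up. It assumes the existing VECTOR_STORE setting is what selects the store (that setting is not shown in this diff), and only relies on VectorType being a StrEnum as declared above:

# With VECTOR_STORE=alibabacloud_mysql configured, the match statement above resolves
# VectorType.ALIBABACLOUD_MYSQL to AlibabaCloudMySQLVectorFactory.
assert VectorType("alibabacloud_mysql") is VectorType.ALIBABACLOUD_MYSQL
assert str(VectorType.ALIBABACLOUD_MYSQL) == "alibabacloud_mysql"  # StrEnum members render as their value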
|
||||
|
|
|
|||
|
|
@@ -76,11 +76,11 @@ class VikingDBVector(BaseVector):
|
|||
|
||||
if not self._has_collection():
|
||||
fields = [
|
||||
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
|
||||
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
|
||||
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
|
||||
Field(field_name=vdb_Field.PRIMARY_KEY, field_type=FieldType.String, is_primary_key=True),
|
||||
Field(field_name=vdb_Field.METADATA_KEY, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.GROUP_KEY, field_type=FieldType.String),
|
||||
Field(field_name=vdb_Field.CONTENT_KEY, field_type=FieldType.Text),
|
||||
Field(field_name=vdb_Field.VECTOR, field_type=FieldType.Vector, dim=dimension),
|
||||
]
|
||||
|
||||
self._client.create_collection(
|
||||
|
|
@ -100,7 +100,7 @@ class VikingDBVector(BaseVector):
|
|||
collection_name=self._collection_name,
|
||||
index_name=self._index_name,
|
||||
vector_index=vector_index,
|
||||
partition_by=vdb_Field.GROUP_KEY.value,
|
||||
partition_by=vdb_Field.GROUP_KEY,
|
||||
description="Index For Dify",
|
||||
)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
|
@ -126,11 +126,11 @@ class VikingDBVector(BaseVector):
|
|||
# FIXME: fix the type of metadata later
|
||||
doc = Data(
|
||||
{
|
||||
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
|
||||
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
|
||||
vdb_Field.CONTENT_KEY.value: page_content,
|
||||
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
|
||||
vdb_Field.GROUP_KEY.value: self._group_id,
|
||||
vdb_Field.PRIMARY_KEY: metadatas[i]["doc_id"], # type: ignore
|
||||
vdb_Field.VECTOR: embeddings[i] if embeddings else None,
|
||||
vdb_Field.CONTENT_KEY: page_content,
|
||||
vdb_Field.METADATA_KEY: json.dumps(metadata),
|
||||
vdb_Field.GROUP_KEY: self._group_id,
|
||||
}
|
||||
)
|
||||
docs.append(doc)
|
||||
|
|
@ -151,7 +151,7 @@ class VikingDBVector(BaseVector):
|
|||
# Note: Metadata field value is a dict, but vikingdb field
|
||||
# not support json type
|
||||
results = self._client.get_index(self._collection_name, self._index_name).search(
|
||||
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
|
||||
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
|
||||
# max value is 5000
|
||||
limit=5000,
|
||||
)
|
||||
|
|
@ -161,7 +161,7 @@ class VikingDBVector(BaseVector):
|
|||
|
||||
ids = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if metadata.get(key) == value:
|
||||
|
|
@ -189,12 +189,12 @@ class VikingDBVector(BaseVector):
|
|||
|
||||
docs = []
|
||||
for result in results:
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
metadata = result.fields.get(vdb_Field.METADATA_KEY)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)
|
||||
docs.append(doc)
|
||||
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
|
||||
return docs
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,6 @@ import datetime
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import weaviate # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
|
@ -45,8 +44,8 @@ class WeaviateVector(BaseVector):
|
|||
client = weaviate.Client(
|
||||
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
|
||||
)
|
||||
except requests.ConnectionError:
|
||||
raise ConnectionError("Vector database connection error")
|
||||
except Exception as exc:
|
||||
raise ConnectionError("Vector database connection error") from exc
|
||||
|
||||
client.batch.configure(
|
||||
# `batch_size` takes an `int` value to enable auto-batching
|
||||
|
|
@ -105,7 +104,7 @@ class WeaviateVector(BaseVector):
|
|||
|
||||
with self._client.batch as batch:
|
||||
for i, text in enumerate(texts):
|
||||
data_properties = {Field.TEXT_KEY.value: text}
|
||||
data_properties = {Field.TEXT_KEY: text}
|
||||
if metadatas is not None:
|
||||
# metadata maybe None
|
||||
for key, val in (metadatas[i] or {}).items():
|
||||
|
|
@ -183,7 +182,7 @@ class WeaviateVector(BaseVector):
|
|||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
collection_name = self._collection_name
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
properties.append(Field.TEXT_KEY)
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
||||
vector = {"vector": query_vector}
|
||||
|
|
@ -205,7 +204,7 @@ class WeaviateVector(BaseVector):
|
|||
|
||||
docs_and_scores = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
text = res.pop(Field.TEXT_KEY)
|
||||
score = 1 - res["_additional"]["distance"]
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
|
||||
|
|
@ -233,7 +232,7 @@ class WeaviateVector(BaseVector):
|
|||
collection_name = self._collection_name
|
||||
content: dict[str, Any] = {"concepts": [query]}
|
||||
properties = self._attributes
|
||||
properties.append(Field.TEXT_KEY.value)
|
||||
properties.append(Field.TEXT_KEY)
|
||||
if kwargs.get("search_distance"):
|
||||
content["certainty"] = kwargs.get("search_distance")
|
||||
query_obj = self._client.query.get(collection_name, properties)
|
||||
|
|
@ -251,7 +250,7 @@ class WeaviateVector(BaseVector):
|
|||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
docs = []
|
||||
for res in result["data"]["Get"][collection_name]:
|
||||
text = res.pop(Field.TEXT_KEY.value)
|
||||
text = res.pop(Field.TEXT_KEY)
|
||||
additional = res.pop("_additional")
|
||||
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
|
||||
return docs
|
||||
|
|
|
|||
|
|
@@ -1,11 +1,11 @@
from collections.abc import Mapping
-from enum import Enum
+from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field


-class DatasourceStreamEvent(Enum):
+class DatasourceStreamEvent(StrEnum):
"""
Datasource Stream event
"""

@@ -20,12 +20,12 @@ class BaseDatasourceEvent(BaseModel):


class DatasourceErrorEvent(BaseDatasourceEvent):
-event: str = DatasourceStreamEvent.ERROR.value
+event: DatasourceStreamEvent = DatasourceStreamEvent.ERROR
error: str = Field(..., description="error message")


class DatasourceCompletedEvent(BaseDatasourceEvent):
-event: str = DatasourceStreamEvent.COMPLETED.value
+event: DatasourceStreamEvent = DatasourceStreamEvent.COMPLETED
data: Mapping[str, Any] | list = Field(..., description="result")
total: int | None = Field(default=0, description="total")
completed: int | None = Field(default=0, description="completed")

@@ -33,6 +33,6 @@ class DatasourceCompletedEvent(BaseDatasourceEvent):


class DatasourceProcessingEvent(BaseDatasourceEvent):
-event: str = DatasourceStreamEvent.PROCESSING.value
+event: DatasourceStreamEvent = DatasourceStreamEvent.PROCESSING
total: int | None = Field(..., description="total")
completed: int | None = Field(..., description="completed")

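A hedged sketch of why the event fields can now be typed with the enum itself: with StrEnum, a Pydantic model still serializes the member to the same plain string as before. The class and field names below are simplified stand-ins for the real entities:

from enum import StrEnum
from pydantic import BaseModel, Field

class StreamEvent(StrEnum):
    ERROR = "error"
    COMPLETED = "completed"

class ErrorEvent(BaseModel):
    event: StreamEvent = StreamEvent.ERROR
    error: str = Field(..., description="error message")

evt = ErrorEvent(error="boom")
# JSON-mode dump still emits the bare string, so downstream consumers see no change.
assert evt.model_dump(mode="json") == {"event": "error", "error": "boom"}
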
@@ -45,7 +45,7 @@ class ExtractProcessor:
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
-datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
+datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = "\n"

@@ -76,7 +76,7 @@ class ExtractProcessor:
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
-extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
+extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
delimiter = "\n"
return delimiter.join(

@@ -92,7 +92,7 @@ class ExtractProcessor:
def extract(
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
-if extract_setting.datasource_type == DatasourceType.FILE.value:
+if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"

@@ -163,7 +163,7 @@ class ExtractProcessor:
# txt
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
-elif extract_setting.datasource_type == DatasourceType.NOTION.value:
+elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,

@@ -174,7 +174,7 @@ class ExtractProcessor:
credential_id=extract_setting.notion_info.credential_id,
)
return extractor.extract()
-elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
+elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(

@@ -2,7 +2,7 @@ import json
import time
from typing import Any, cast

-import requests
+import httpx

from extensions.ext_storage import storage


@@ -104,18 +104,18 @@ class FirecrawlApp:
def _prepare_headers(self) -> dict[str, Any]:
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

-def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
+def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
-response = requests.post(url, headers=headers, json=data)
+response = httpx.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
return response

-def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response:
+def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
for attempt in range(retries):
-response = requests.get(url, headers=headers)
+response = httpx.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:

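For reference, the retry-on-502 shape that the hunk above ports from requests to httpx looks roughly like the following standalone sketch; the URL, payload, and helper name are placeholders, not Firecrawl's real API surface:

import time

import httpx

def post_with_retries(url: str, payload: dict, headers: dict,
                      retries: int = 3, backoff_factor: float = 0.5) -> httpx.Response:
    response: httpx.Response | None = None
    for attempt in range(retries):
        response = httpx.post(url, headers=headers, json=payload)
        if response.status_code != 502:
            return response
        # only 502 is retried, with exponential backoff between attempts
        time.sleep(backoff_factor * (2 ** attempt))
    assert response is not None
    return response
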
@@ -3,7 +3,7 @@ import logging
import operator
from typing import Any, cast

-import requests
+import httpx

from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor

@@ -92,7 +92,7 @@ class NotionExtractor(BaseExtractor):
if next_cursor:
current_query["start_cursor"] = next_cursor

-res = requests.post(
+res = httpx.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,

@@ -160,7 +160,7 @@ class NotionExtractor(BaseExtractor):
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}
try:
-res = requests.request(
+res = httpx.request(
"GET",
block_url,
headers={

@@ -173,7 +173,7 @@ class NotionExtractor(BaseExtractor):
if res.status_code != 200:
raise ValueError(f"Error fetching Notion block data: {res.text}")
data = res.json()
-except requests.RequestException as e:
+except httpx.HTTPError as e:
raise ValueError("Error fetching Notion block data") from e
if "results" not in data or not isinstance(data["results"], list):
raise ValueError("Error fetching Notion block data")

@@ -222,7 +222,7 @@ class NotionExtractor(BaseExtractor):
while True:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}

-res = requests.request(
+res = httpx.request(
"GET",
block_url,
headers={

@@ -282,7 +282,7 @@ class NotionExtractor(BaseExtractor):
while not done:
query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor}

-res = requests.request(
+res = httpx.request(
"GET",
block_url,
headers={

@@ -354,7 +354,7 @@ class NotionExtractor(BaseExtractor):

query_dict: dict[str, Any] = {}

-res = requests.request(
+res = httpx.request(
"GET",
retrieve_page_url,
headers={

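The exception swap above follows from the two clients' different hierarchies: requests groups failures under RequestException, while httpx uses HTTPError as the common base for transport errors and raise_for_status() failures. A small sketch of the resulting handler, with a placeholder URL and token rather than the extractor's real request:

import httpx

def fetch_notion_json(url: str, token: str) -> dict:
    try:
        res = httpx.get(url, headers={"Authorization": f"Bearer {token}"})
        res.raise_for_status()
        return res.json()
    except httpx.HTTPError as e:
        # covers connection errors, timeouts, and non-2xx responses raised above
        raise ValueError("Error fetching Notion block data") from e
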
@@ -1,6 +1,7 @@
import logging
import os

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -49,7 +50,8 @@ class UnstructuredWordExtractor(BaseExtractor):

from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

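The same substitution recurs in the email, epub, markdown, msg, and xml extractors below: the hard-coded 2000-character chunk size gives way to the configured limit. A condensed sketch of the shared pattern, with the config value stubbed as a constant (the real code reads dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH):

from unstructured.chunking.title import chunk_by_title

INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = 4000  # stand-in for the dify_config setting

def chunk_elements(elements):
    max_characters = INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
    # both the hard cap and the "combine small sections" threshold follow the config
    return chunk_by_title(
        elements,
        max_characters=max_characters,
        combine_text_under_n_chars=max_characters,
    )
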
@@ -4,6 +4,7 @@ import logging

from bs4 import BeautifulSoup

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -46,7 +47,8 @@ class UnstructuredEmailExtractor(BaseExtractor):

from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

@@ -2,6 +2,7 @@ import logging

import pypandoc  # type: ignore

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -40,7 +41,8 @@ class UnstructuredEpubExtractor(BaseExtractor):

from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

@@ -1,5 +1,6 @@
import logging

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -32,7 +33,8 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
elements = partition_md(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

@@ -1,5 +1,6 @@
import logging

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -31,7 +32,8 @@ class UnstructuredMsgExtractor(BaseExtractor):
elements = partition_msg(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

@@ -1,5 +1,6 @@
import logging

+from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document


@@ -32,7 +33,8 @@ class UnstructuredXmlExtractor(BaseExtractor):

from unstructured.chunking.title import chunk_by_title

-chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
+max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
+chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

@@ -3,8 +3,8 @@ from collections.abc import Generator
from typing import Union
from urllib.parse import urljoin

-import requests
-from requests import Response
+import httpx
+from httpx import Response

from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,

@@ -20,28 +20,45 @@ class BaseAPIClient:
self.session = self.init_session()

def init_session(self):
-session = requests.Session()
-session.headers.update({"X-API-Key": self.api_key})
-session.headers.update({"Content-Type": "application/json"})
-session.headers.update({"Accept": "application/json"})
-session.headers.update({"User-Agent": "WaterCrawl-Plugin"})
-session.headers.update({"Accept-Language": "en-US"})
-return session
+headers = {
+"X-API-Key": self.api_key,
+"Content-Type": "application/json",
+"Accept": "application/json",
+"User-Agent": "WaterCrawl-Plugin",
+"Accept-Language": "en-US",
+}
+return httpx.Client(headers=headers, timeout=None)

+def _request(
+self,
+method: str,
+endpoint: str,
+query_params: dict | None = None,
+data: dict | None = None,
+**kwargs,
+) -> Response:
+stream = kwargs.pop("stream", False)
+url = urljoin(self.base_url, endpoint)
+if stream:
+request = self.session.build_request(method, url, params=query_params, json=data)
+return self.session.send(request, stream=True, **kwargs)

+return self.session.request(method, url, params=query_params, json=data, **kwargs)

def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
-return self.session.get(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
+return self._request("GET", endpoint, query_params=query_params, **kwargs)

def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
-return self.session.post(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
+return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)

def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
-return self.session.put(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
+return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)

def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
-return self.session.delete(urljoin(self.base_url, endpoint), params=query_params, **kwargs)
+return self._request("DELETE", endpoint, query_params=query_params, **kwargs)

def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
-return self.session.patch(urljoin(self.base_url, endpoint), params=query_params, json=data, **kwargs)
+return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)


class WaterCrawlAPIClient(BaseAPIClient):

@@ -49,14 +66,17 @@ class WaterCrawlAPIClient(BaseAPIClient):
super().__init__(api_key, base_url)

def process_eventstream(self, response: Response, download: bool = False) -> Generator:
-for line in response.iter_lines():
-line = line.decode("utf-8")
-if line.startswith("data:"):
-line = line[5:].strip()
-data = json.loads(line)
-if data["type"] == "result" and download:
-data["data"] = self.download_result(data["data"])
-yield data
+try:
+for raw_line in response.iter_lines():
+line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
+if line.startswith("data:"):
+line = line[5:].strip()
+data = json.loads(line)
+if data["type"] == "result" and download:
+data["data"] = self.download_result(data["data"])
+yield data
+finally:
+response.close()

def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
if response.status_code == 401:

@@ -170,7 +190,10 @@ class WaterCrawlAPIClient(BaseAPIClient):
return event_data["data"]

def download_result(self, result_object: dict):
-response = requests.get(result_object["result"])
-response.raise_for_status()
-result_object["result"] = response.json()
+response = httpx.get(result_object["result"], timeout=None)
+try:
+response.raise_for_status()
+result_object["result"] = response.json()
+finally:
+response.close()
return result_object

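A sketch of the client shape the WaterCrawl hunks converge on: one httpx.Client carrying the default headers, plain calls going through request(), and streaming going through build_request()/send(stream=True) with an explicit close. The endpoint and key below are placeholders, not WaterCrawl's real surface:

import httpx

def make_client(api_key: str) -> httpx.Client:
    headers = {
        "X-API-Key": api_key,
        "Accept": "application/json",
    }
    return httpx.Client(headers=headers, timeout=None)

def stream_events(client: httpx.Client, url: str):
    # streaming responses are not pre-read; close them once iteration is done
    request = client.build_request("GET", url)
    response = client.send(request, stream=True)
    try:
        for line in response.iter_lines():
            if line.startswith("data:"):
                yield line[5:].strip()
    finally:
        response.close()
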
@@ -9,7 +9,7 @@ import uuid
from urllib.parse import urlparse
from xml.etree import ElementTree

-import requests
+import httpx
from docx import Document as DocxDocument

from configs import dify_config

@@ -43,15 +43,19 @@ class WordExtractor(BaseExtractor):

# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
-r = requests.get(self.file_path)
+response = httpx.get(self.file_path, timeout=None)

-if r.status_code != 200:
-raise ValueError(f"Check the url of your file; returned status code {r.status_code}")
+if response.status_code != 200:
+response.close()
+raise ValueError(f"Check the url of your file; returned status code {response.status_code}")

self.web_path = self.file_path
# TODO: use a better way to handle the file
self.temp_file = tempfile.NamedTemporaryFile() # noqa SIM115
-self.temp_file.write(r.content)
+try:
+self.temp_file.write(response.content)
+finally:
+response.close()
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
from configs import dify_config
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,

@@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def retrieve(
self,
-retrieval_method: str,
+retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

@@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule

@@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):

def retrieve(
self,
-retrieval_method: str,
+retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

@@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment

@@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

def retrieve(
self,
-retrieval_method: str,
+retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

@@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset

@@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):

def retrieve(
self,
-retrieval_method: str,
+retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

@@ -8,9 +8,9 @@ class RerankRunnerFactory:
@staticmethod
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
match runner_type:
-case RerankMode.RERANKING_MODEL.value:
+case RerankMode.RERANKING_MODEL:
return RerankModelRunner(*args, **kwargs)
-case RerankMode.WEIGHTED_SCORE.value:
+case RerankMode.WEIGHTED_SCORE:
return WeightRerankRunner(*args, **kwargs)
case _:
raise ValueError(f"Unknown runner type: {runner_type}")

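Dropping `.value` in the match arms works because a StrEnum member is treated as a value pattern and compares equal to the incoming plain string. A tiny self-contained illustration; the member values and runner strings below are assumptions, not necessarily the project's real RerankMode values:

from enum import StrEnum

class RerankMode(StrEnum):
    RERANKING_MODEL = "reranking_model"
    WEIGHTED_SCORE = "weighted_score"

def pick_runner(runner_type: str) -> str:
    match runner_type:
        case RerankMode.RERANKING_MODEL:   # dotted name -> value pattern, compared with ==
            return "rerank-model-runner"
        case RerankMode.WEIGHTED_SCORE:
            return "weight-rerank-runner"
        case _:
            raise ValueError(f"Unknown runner type: {runner_type}")

assert pick_runner("weighted_score") == "weight-rerank-runner"
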
@@ -61,7 +61,7 @@ from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model: dict[str, Any] = {
-"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,

@@ -364,7 +364,7 @@ class DatasetRetrieval:
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
-retrieval_method = "keyword_search"
+retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model

@@ -623,7 +623,7 @@ class DatasetRetrieval:
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
-retrieval_method="keyword_search",
+retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=top_k,

@@ -692,7 +692,7 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
-"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
+"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,

@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum


-class RetrievalMethod(Enum):
+class RetrievalMethod(StrEnum):
SEMANTIC_SEARCH = "semantic_search"
FULL_TEXT_SEARCH = "full_text_search"
HYBRID_SEARCH = "hybrid_search"

@@ -9,8 +9,8 @@ class RetrievalMethod(Enum):

@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
-return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
+return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}

@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
-return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
+return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH, RetrievalMethod.HYBRID_SEARCH}

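Because the class is now a StrEnum, the membership checks keep accepting the plain strings that callers pass in; each member hashes and compares like its value. A quick check of that property, mirroring the methods above:

from enum import StrEnum

class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"
    HYBRID_SEARCH = "hybrid_search"

def is_support_semantic_search(retrieval_method: str) -> bool:
    return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH, RetrievalMethod.HYBRID_SEARCH}

assert is_support_semantic_search("hybrid_search") is True
assert is_support_semantic_search("keyword_search") is False
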
@@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController):

:return: the credentials schema
"""
-return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
+return self.get_credentials_schema_by_type(CredentialType.API_KEY)

def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""

@@ -122,7 +122,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
-if credential_type == CredentialType.API_KEY.value:
+if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")

@@ -134,15 +134,15 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []

-def get_supported_credential_types(self) -> list[str]:
+def get_supported_credential_types(self) -> list[CredentialType]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
-types.append(CredentialType.API_KEY.value)
+types.append(CredentialType.API_KEY)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
-types.append(CredentialType.OAUTH2.value)
+types.append(CredentialType.OAUTH2)
return types

def get_tools(self) -> list[BuiltinTool]:

@@ -290,6 +290,7 @@ class ApiTool(Tool):
method_lc
]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926
url,
+max_retries=0,
params=params,
headers=headers,
cookies=cookies,

@@ -62,7 +62,7 @@ class ToolProviderApiEntity(BaseModel):
for tool in tools:
if tool.get("parameters"):
for parameter in tool.get("parameters"):
-if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
+if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES:
parameter["type"] = "files"
if parameter.get("input_schema") is None:
parameter.pop("input_schema", None)

@@ -111,7 +111,9 @@ class ToolProviderCredentialApiEntity(BaseModel):


class ToolProviderCredentialInfoApiEntity(BaseModel):
-supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
+supported_credential_types: list[CredentialType] = Field(
+description="The supported credential types of the provider"
+)
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)

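Typing the list as list[CredentialType] keeps the wire format identical to the old list[str] field, assuming CredentialType is itself a string-backed enum; the member values below are illustrative only:

from enum import StrEnum
from pydantic import BaseModel, Field

class CredentialType(StrEnum):
    API_KEY = "api-key"     # illustrative value, not necessarily the real one
    OAUTH2 = "oauth2"

class CredentialInfo(BaseModel):
    supported_credential_types: list[CredentialType] = Field(
        description="The supported credential types of the provider"
    )

info = CredentialInfo(supported_credential_types=[CredentialType.API_KEY, CredentialType.OAUTH2])
assert info.model_dump(mode="json") == {"supported_credential_types": ["api-key", "oauth2"]}
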
@@ -113,7 +113,7 @@ class ApiProviderAuthType(StrEnum):
# normalize & tiny alias for backward compatibility
v = (value or "").strip().lower()
if v == "api_key":
-v = cls.API_KEY_HEADER.value
+v = cls.API_KEY_HEADER

for mode in cls:
if mode.value == v:

Some files were not shown because too many files have changed in this diff.