Merge branch 'main' into feat/memory-orchestration-fed

Commit 985becbc41 by zxhlyh, 2025-08-08 13:45:47 +08:00
559 changed files with 27719 additions and 3488 deletions

.env.example — 1197 changed lines (diff suppressed because it is too large)

.github/ISSUE_TEMPLATE/chore.yaml — 44 changed lines (vendored)

@ -0,0 +1,44 @@
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
body:
- type: checkboxes
attributes:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
required: true
- label: 【中文用户 & Non English User】请使用英语提交否则会被关闭
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
id: description
attributes:
label: Description
placeholder: "Describe the refactor you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor is necessary."
validations:
required: false
- type: textarea
id: additional-context
attributes:
label: Additional Context
placeholder: "Add any other context or screenshots about the request here."
validations:
required: false

View File

@ -99,3 +99,6 @@ jobs:
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh

View File

@ -9,6 +9,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

View File

@ -7,6 +7,7 @@ on:
- "deploy/dev"
- "deploy/enterprise"
- "build/**"
- "release/e-*"
tags:
- "*"

View File

@ -5,6 +5,10 @@ on:
types: [closed]
branches: [main]
permissions:
contents: write
pull-requests: write
jobs:
check-and-update:
if: github.event.pull_request.merged == true
@ -16,7 +20,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
persist-credentials: false
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
@ -49,7 +53,7 @@ jobs:
if: env.FILES_CHANGED == 'true'
run: pnpm install --frozen-lockfile
- name: Run npm script
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
run: pnpm run auto-gen-i18n
@ -57,6 +61,7 @@ jobs:
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.

.gitignore — 1 changed line (vendored)

@ -215,3 +215,4 @@ mise.toml
# AI Assistant
.roo/
api/.env.backup
/clickzetta

View File

@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Deploy to AKS with Azure Devops Pipeline
One-Click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -217,6 +217,10 @@ docker compose up -d
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### استخدام Azure Devops Pipeline للنشر على AKS
انشر Dify على AKS بنقرة واحدة باستخدام [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## المساهمة

View File

@ -235,6 +235,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### AKS-এ ডিপ্লয় করার জন্য Azure Devops Pipeline ব্যবহার
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন
## Contributing

View File

@ -233,6 +233,9 @@ docker compose up -d
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
#### 使用 Azure Devops Pipeline 部署到AKS
使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS
## Star History

View File

@ -230,6 +230,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Verwendung von Azure Devops Pipeline für AKS-Bereitstellung
Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden
## Contributing

View File

@ -230,6 +230,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Uso de Azure Devops Pipeline para implementar en AKS
Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contribuir

View File

@ -228,6 +228,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Utilisation d'Azure Devops Pipeline pour déployer sur AKS
Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contribuer

View File

@ -227,6 +227,10 @@ docker compose up -d
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
#### AKSへのデプロイにAzure Devops Pipelineを使用
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ
## 貢献

View File

@ -228,6 +228,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### AKS 'e' Deploy je Azure Devops Pipeline lo'laH
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy
## Contributing

View File

@ -222,6 +222,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
#### AKS에 배포하기 위해 Azure Devops Pipeline 사용
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포
## 기여

View File

@ -227,6 +227,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Usando Azure Devops Pipeline para Implantar no AKS
Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contribuindo

View File

@ -228,6 +228,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Uporaba Azure Devops Pipeline za uvajanje v AKS
Z enim klikom namestite Dify v AKS z uporabo [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Prispevam

View File

@ -221,6 +221,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
#### AKS'ye Dağıtım için Azure Devops Pipeline Kullanımı
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın
## Katkıda Bulunma

View File

@ -233,6 +233,10 @@ Dify 的所有功能都提供相應的 API因此您可以輕鬆地將 Dify
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
#### 使用 Azure Devops Pipeline 部署到AKS
使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS
## 貢獻

View File

@ -224,6 +224,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Sử dụng Azure Devops Pipeline để Triển khai lên AKS
Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure Devops Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Đóng góp

View File

@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
TABLESTORE_INSTANCE_NAME=instance-name
TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com

View File

@ -19,7 +19,7 @@ RUN apt-get update \
# Install Python dependencies
COPY pyproject.toml uv.lock ./
RUN uv sync --locked
RUN uv sync --locked --no-dev
# production stage
FROM base AS production

View File

@ -5,10 +5,11 @@ import secrets
from typing import Any, Optional
import click
import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
@ -180,8 +181,8 @@ def migrate_annotation_vector_database():
)
if not apps:
break
except NotFound:
break
except SQLAlchemyError:
raise
page += 1
for app in apps:
@ -307,8 +308,8 @@ def migrate_knowledge_vector_database():
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
except SQLAlchemyError:
raise
page += 1
for dataset in datasets:
@ -457,7 +458,7 @@ def convert_to_agent_apps():
"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query))
rs = conn.execute(sa.text(sql_query))
apps = []
for i in rs:
@ -560,8 +561,8 @@ def old_metadata_migration():
.order_by(DatasetDocument.created_at.desc())
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
except SQLAlchemyError:
raise
if not documents:
break
for document in documents:
@ -702,7 +703,7 @@ def fix_app_site_missing():
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
where sites.id is null limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
rs = conn.execute(sa.text(sql))
processed_count = 0
for i in rs:
@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool):
)
orphaned_message_files = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool):
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
except Exception as e:
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
return
@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append(str(i[0]))
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))

View File

@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
@ -52,6 +54,7 @@ class StorageConfig(BaseSettings):
"aliyun-oss",
"azure-blob",
"baidu-obs",
"clickzetta-volume",
"google-storage",
"huawei-obs",
"oci-storage",
@ -61,8 +64,9 @@ class StorageConfig(BaseSettings):
"local",
] = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)
@ -215,7 +219,7 @@ class DatabaseConfig(BaseSettings):
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
default="redis",
)
@ -245,7 +249,12 @@ class CeleryConfig(DatabaseConfig):
@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
if self.CELERY_BACKEND in ("database", "rabbitmq"):
return f"db+{self.SQLALCHEMY_DATABASE_URI}"
elif self.CELERY_BACKEND == "redis":
return self.CELERY_BROKER_URL
else:
return None
@property
def BROKER_USE_SSL(self) -> bool:
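A minimal sketch of how the reworked result-backend selection above behaves, pulled out into a standalone helper (the function name and example URIs are placeholders, not part of the codebase):

def resolve_result_backend(celery_backend: str, sqlalchemy_uri: str, broker_url: str) -> str | None:
    # 'database' and 'rabbitmq' store task results in the database; only 'redis'
    # reuses the broker URL; any other value disables the result backend.
    if celery_backend in ("database", "rabbitmq"):
        return f"db+{sqlalchemy_uri}"
    elif celery_backend == "redis":
        return broker_url
    return None

assert resolve_result_backend("rabbitmq", "postgresql://dify", "redis://localhost:6379/1") == "db+postgresql://dify"
assert resolve_result_backend("redis", "postgresql://dify", "redis://localhost:6379/1") == "redis://localhost:6379/1"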
@ -298,6 +307,7 @@ class MiddlewareConfig(
AliyunOSSStorageConfig,
AzureBlobStorageConfig,
BaiduOBSStorageConfig,
ClickZettaVolumeStorageConfig,
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
@ -310,6 +320,7 @@ class MiddlewareConfig(
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
MilvusConfig,
MyScaleConfig,

View File

@ -0,0 +1,65 @@
"""ClickZetta Volume Storage Configuration"""
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class ClickZettaVolumeStorageConfig(BaseSettings):
"""Configuration for ClickZetta Volume storage."""
CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
description="Username for ClickZetta Volume authentication",
default=None,
)
CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
description="Password for ClickZetta Volume authentication",
default=None,
)
CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
description="ClickZetta instance identifier",
default=None,
)
CLICKZETTA_VOLUME_SERVICE: str = Field(
description="ClickZetta service endpoint",
default="api.clickzetta.com",
)
CLICKZETTA_VOLUME_WORKSPACE: str = Field(
description="ClickZetta workspace name",
default="quick_start",
)
CLICKZETTA_VOLUME_VCLUSTER: str = Field(
description="ClickZetta virtual cluster name",
default="default_ap",
)
CLICKZETTA_VOLUME_SCHEMA: str = Field(
description="ClickZetta schema name",
default="dify",
)
CLICKZETTA_VOLUME_TYPE: str = Field(
description="ClickZetta volume type (table|user|external)",
default="user",
)
CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
description="ClickZetta volume name for external volumes",
default=None,
)
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
description="Prefix for ClickZetta volume table names",
default="dataset_",
)
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
description="Directory prefix for User Volume to organize Dify files",
default="dify_km",
)
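Because the class extends pydantic-settings' BaseSettings, these fields are populated from environment variables. A hedged sketch of how they would be set; the absolute import path is inferred from the relative import shown earlier, and all values are placeholders:

import os

from configs.middleware.storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig

# Placeholder credentials; any field left unset falls back to the defaults declared above.
os.environ.update({
    "CLICKZETTA_VOLUME_USERNAME": "demo_user",
    "CLICKZETTA_VOLUME_PASSWORD": "demo_password",
    "CLICKZETTA_VOLUME_INSTANCE": "my-instance",
    "CLICKZETTA_VOLUME_TYPE": "user",
})
config = ClickZettaVolumeStorageConfig()
print(config.CLICKZETTA_VOLUME_WORKSPACE)    # -> "quick_start" (default)
print(config.CLICKZETTA_VOLUME_DIFY_PREFIX)  # -> "dify_km" (default)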

View File

@ -0,0 +1,69 @@
from typing import Optional
from pydantic import BaseModel, Field
class ClickzettaConfig(BaseModel):
"""
Clickzetta Lakehouse vector database configuration
"""
CLICKZETTA_USERNAME: Optional[str] = Field(
description="Username for authenticating with Clickzetta Lakehouse",
default=None,
)
CLICKZETTA_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Clickzetta Lakehouse",
default=None,
)
CLICKZETTA_INSTANCE: Optional[str] = Field(
description="Clickzetta Lakehouse instance ID",
default=None,
)
CLICKZETTA_SERVICE: Optional[str] = Field(
description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
default="api.clickzetta.com",
)
CLICKZETTA_WORKSPACE: Optional[str] = Field(
description="Clickzetta workspace name",
default="default",
)
CLICKZETTA_VCLUSTER: Optional[str] = Field(
description="Clickzetta virtual cluster name",
default="default_ap",
)
CLICKZETTA_SCHEMA: Optional[str] = Field(
description="Database schema name in Clickzetta",
default="public",
)
CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
description="Batch size for bulk insert operations",
default=100,
)
CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
description="Enable inverted index for full-text search capabilities",
default=True,
)
CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
description="Analyzer type for full-text search: keyword, english, chinese, unicode",
default="chinese",
)
CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
default="smart",
)
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
description="Distance function for vector similarity: l2_distance or cosine_distance",
default="cosine_distance",
)

View File

@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
description="AccessKey secret for the instance name",
default=None,
)
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
description="Whether to normalize full-text search scores to [0, 1]",
default=False,
)

View File

@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])

View File

@ -84,6 +84,7 @@ from .datasets import (
external,
hit_testing,
metadata,
upload_file,
website,
)

View File

@ -100,7 +100,7 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
class AnnotationApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -123,6 +123,23 @@ class AnnotationListApi(Resource):
}
return response, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
@setup_required
@login_required
@account_initialization_required
@ -131,8 +148,25 @@ class AnnotationListApi(Resource):
raise Forbidden()
app_id = str(app_id)
AppAnnotationService.clear_all_annotations(app_id)
return {"result": "success"}, 204
# Use request.args.getlist to get annotation_ids array directly
annotation_ids = request.args.getlist("annotation_id")
# If annotation_ids are provided, handle batch deletion
if annotation_ids:
# Check if any annotation_ids contain empty strings or invalid values
if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id):
return {
"code": "bad_request",
"message": "annotation_ids are required if the parameter is provided.",
}, 400
result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
return result, 204
# If no annotation_ids are provided, handle clearing all annotations
else:
AppAnnotationService.clear_all_annotations(app_id)
return {"result": "success"}, 204
class AnnotationExportApi(Resource):
@ -149,25 +183,6 @@ class AnnotationExportApi(Resource):
return response, 200
class AnnotationCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
def post(self, app_id):
if not current_user.is_editor:
raise Forbidden()
app_id = str(app_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@setup_required
@login_required
@ -210,14 +225,15 @@ class AnnotationBatchImportApi(Resource):
raise Forbidden()
app_id = str(app_id)
# get file from request
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
@ -276,7 +292,7 @@ api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply
api.add_resource(
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
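The DELETE handler above now distinguishes between batch deletion (repeated annotation_id query parameters) and clearing everything (no parameter). A hedged usage sketch against the console API; the base URL, app id, and authentication are placeholders, not taken from the diff:

import requests

BASE = "https://dify.example.com/console/api"    # placeholder console base URL
APP_ID = "00000000-0000-0000-0000-000000000000"  # placeholder app id
session = requests.Session()                     # placeholder; a real call needs console login credentials

# Batch delete: repeat the annotation_id query parameter once per annotation.
session.delete(
    f"{BASE}/apps/{APP_ID}/annotations",
    params=[("annotation_id", "id-1"), ("annotation_id", "id-2")],
)

# Clear all annotations: omit annotation_id entirely.
session.delete(f"{BASE}/apps/{APP_ID}/annotations")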

View File

@ -28,6 +28,12 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class AppListApi(Resource):
@setup_required
@login_required
@ -94,7 +100,7 @@ class AppListApi(Resource):
"""Create app"""
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
@ -146,7 +152,7 @@ class AppApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
@ -189,7 +195,7 @@ class AppCopyApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("description", type=_validate_description_length, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")

View File

@ -67,7 +67,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "message_count": i.message_count})
@ -176,7 +176,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -234,7 +234,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
@ -310,7 +310,7 @@ ORDER BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
@ -373,7 +373,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -435,7 +435,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
@ -495,7 +495,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})

View File

@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -71,7 +72,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
@ -133,7 +134,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -195,7 +196,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -277,7 +278,7 @@ GROUP BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}

View File

@ -41,7 +41,7 @@ def _validate_name(name):
def _validate_description_length(description):
if len(description) > 400:
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -113,7 +113,7 @@ class DatasetListApi(Resource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",
@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
):
return {
"retrieval_method": [
@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
):
return {
"retrieval_method": [

View File

@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource):
return marshal(document_dict, document_status_fields)
class DocumentDetailApi(DocumentResource):
class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"}
@setup_required
@ -730,6 +730,28 @@ class DocumentDetailApi(DocumentResource):
return response, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentProcessingApi(DocumentResource):
@setup_required
@ -768,30 +790,6 @@ class DocumentProcessingApi(DocumentResource):
return {"result": "success"}, 200
class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@ -1037,11 +1035,10 @@ api.add_resource(
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")

View File

@ -0,0 +1,62 @@
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(Resource):
@setup_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
.first()
)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")

View File

@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
code = 429
class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit"
class ComplianceRateLimitError(BaseHTTPException):
error_code = "compliance_rate_limit"
description = "Rate limit exceeded for downloading compliance report."
code = 429

View File

@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource):
# filter out apps that user doesn't have access to
if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id
res = []
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
# Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list:
webapp_setting = webapp_settings.get(installed_app["app"].id)
if not webapp_setting:
app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue
if webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_code=app_code,
):
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id,
app_codes=app_codes,
)
# Keep only allowed apps
res = []
for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id]
if permissions.get(app_code):
res.append(installed_app)
installed_app_list = res
logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)

View File

@ -49,7 +49,6 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
file = request.files["file"]
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -58,6 +57,7 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError

View File

@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
# get file from request
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# get file from request
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError

View File

@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

View File

@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.service_api import api
@ -30,6 +30,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -113,7 +114,7 @@ class ChatApi(Resource):
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
parser.add_argument("workflow_id", type=str, required=False, location="json")
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
@ -128,6 +129,12 @@ class ChatApi(Resource):
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
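With the new optional workflow_id field, a Service API chat request can target a specific published workflow; unknown ids map to 404 and draft or malformed ids to 400, as the handlers above show. A hedged usage sketch in which the base URL, API key, and ids are placeholders and the /chat-messages path is assumed rather than shown in this diff:

import requests

resp = requests.post(
    "https://api.dify.example/v1/chat-messages",               # placeholder base URL, assumed route
    headers={"Authorization": "Bearer app-xxxxxxxxxxxxxxxx"},   # placeholder Service API key
    json={
        "query": "Hello",
        "inputs": {},
        "user": "end-user-1",
        "response_mode": "blocking",
        "workflow_id": "00000000-0000-0000-0000-000000000000",  # placeholder workflow id
    },
)
print(resp.status_code, resp.json())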

View File

@ -1,7 +1,9 @@
import json
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api import api
@ -15,6 +17,7 @@ from fields.conversation_fields import (
simple_conversation_fields,
)
from fields.conversation_variable_fields import (
conversation_variable_fields,
conversation_variable_infinite_scroll_pagination_fields,
)
from libs.helper import uuid_value
@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
raise NotFound("Conversation Not Exists.")
class ConversationVariableDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(conversation_variable_fields)
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
variable_id = str(variable_id)
parser = reqparse.RequestParser()
parser.add_argument("value", required=True, location="json")
args = parser.parse_args()
try:
return ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:
raise NotFound("Conversation Variable Not Exists.")
except services.errors.conversation.ConversationVariableTypeMismatchError as e:
raise BadRequest(str(e))
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
api.add_resource(
ConversationVariableDetailApi,
"/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
endpoint="conversation_variable_detail",
methods=["PUT"],
)
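Because the handler above runs json.loads() on the submitted value, clients send the new value as a JSON-encoded string. A hedged usage sketch for the PUT route registered above (base URL, API key, and ids are placeholders):

import json

import requests

resp = requests.put(
    "https://api.dify.example/v1/conversations/<conversation_id>/variables/<variable_id>",
    headers={"Authorization": "Bearer app-xxxxxxxxxxxxxxxx"},  # placeholder Service API key
    json={
        "user": "end-user-1",                                  # end-user identifier
        "value": json.dumps({"step": 2, "done": False}),       # value travels as a JSON-encoded string
    },
)
print(resp.status_code, resp.json())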

View File

@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class FileNotFoundError(BaseHTTPException):
error_code = "file_not_found"
description = "The requested file was not found."
code = 404
class FileAccessDeniedError(BaseHTTPException):
error_code = "file_access_denied"
description = "Access to the requested file is denied."
code = 403

View File

@ -20,18 +20,17 @@ class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model: App, end_user: EndUser):
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if not file.mimetype:
raise UnsupportedFileTypeError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError
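The reordering above matters because indexing request.files before checking membership raises Werkzeug's generic 400 (BadRequestKeyError) instead of the intended NoFileUploadedError. A minimal sketch of the corrected pattern, using stand-in exception classes rather than the controller's own:

from flask import request

class NoFileUploadedError(Exception):  # stand-in for the controller's error class
    pass

class TooManyFilesError(Exception):    # stand-in for the controller's error class
    pass

def read_single_upload():
    # Check presence and count before touching request.files["file"], so callers
    # see the domain-specific errors rather than a raw KeyError/400.
    if "file" not in request.files:
        raise NoFileUploadedError()
    if len(request.files) > 1:
        raise TooManyFilesError()
    return request.files["file"]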

View File

@ -0,0 +1,186 @@
import logging
from urllib.parse import quote
from flask import Response
from flask_restful import Resource, reqparse
from controllers.service_api import api
from controllers.service_api.app.error import (
FileAccessDeniedError,
FileNotFoundError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__)
class FilePreviewApi(Resource):
"""
Service API File Preview endpoint
Provides secure file preview/download functionality for external API users.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: str):
"""
Preview/Download a file that was uploaded via Service API
Args:
app_model: The authenticated app model
end_user: The authenticated end user (optional)
file_id: UUID of the file to preview
Query Parameters:
user: Optional user identifier
as_attachment: Boolean, whether to download as attachment (default: false)
Returns:
Stream response with file content
Raises:
FileNotFoundError: File does not exist
FileAccessDeniedError: File access denied (not owned by app)
"""
file_id = str(file_id)
# Parse query parameters
parser = reqparse.RequestParser()
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
# Validate file ownership and get file objects
message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:
generator = storage.load(upload_file.key, stream=True)
except Exception as e:
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
# Build response with appropriate headers
response = self._build_file_response(generator, upload_file, args["as_attachment"])
return response
def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]:
"""
Validate that the file belongs to a message within the requesting app's context
Security validations performed:
1. File exists in MessageFile table (was used in a conversation)
2. Message belongs to the requesting app
3. UploadFile record exists and is accessible
4. File tenant matches app tenant (additional security layer)
Args:
file_id: UUID of the file to validate
app_id: UUID of the requesting app
Returns:
Tuple of (MessageFile, UploadFile) if validation passes
Raises:
FileNotFoundError: File or related records not found
FileAccessDeniedError: File does not belong to the app's context
"""
try:
# Input validation
if not file_id or not app_id:
raise FileAccessDeniedError("Invalid file or app identifier")
# First, find the MessageFile that references this upload file
message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()
if not message_file:
raise FileNotFoundError("File not found in message context")
# Get the message and verify it belongs to the requesting app
message = (
db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
)
if not message:
raise FileAccessDeniedError("File access denied: not owned by requesting app")
# Get the actual upload file record
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise FileNotFoundError("Upload file record not found")
# Additional security: verify tenant isolation
app = db.session.query(App).where(App.id == app_id).first()
if app and upload_file.tenant_id != app.tenant_id:
raise FileAccessDeniedError("File access denied: tenant mismatch")
return message_file, upload_file
except (FileNotFoundError, FileAccessDeniedError):
# Re-raise our custom exceptions
raise
except Exception as e:
# Log unexpected errors for debugging
logger.exception(
"Unexpected error during file ownership validation",
extra={"file_id": file_id, "app_id": app_id, "error": str(e)},
)
raise FileAccessDeniedError("File access validation failed")
def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response:
"""
Build Flask Response object with appropriate headers for file streaming
Args:
generator: File content generator from storage
upload_file: UploadFile database record
as_attachment: Whether to set Content-Disposition as attachment
Returns:
Flask Response object with streaming file content
"""
response = Response(
generator,
mimetype=upload_file.mime_type,
direct_passthrough=True,
headers={},
)
# Add Content-Length if known
if upload_file.size and upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
# Add Accept-Ranges header for audio/video files to support seeking
if upload_file.mime_type in [
"audio/mpeg",
"audio/wav",
"audio/mp4",
"audio/ogg",
"audio/flac",
"audio/aac",
"video/mp4",
"video/webm",
"video/quicktime",
"audio/x-m4a",
]:
response.headers["Accept-Ranges"] = "bytes"
# Set Content-Disposition for downloads
if as_attachment and upload_file.name:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
# Override content-type for downloads to force download
response.headers["Content-Type"] = "application/octet-stream"
# Add caching headers for performance
response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour
return response
# Register the API endpoint
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
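A hedged usage sketch for the preview route registered above, streaming the file to disk; the base URL, Service API key, and file id are placeholders:

import requests

API_KEY = "app-xxxxxxxxxxxxxxxx"                  # placeholder Service API key
FILE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder upload file id

resp = requests.get(
    f"https://api.dify.example/v1/files/{FILE_ID}/preview",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"as_attachment": "true"},             # request a download-style response
    stream=True,
)
resp.raise_for_status()
with open("preview_download.bin", "wb") as fh:
    for chunk in resp.iter_content(chunk_size=8192):
        fh.write(chunk)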

View File

@ -5,7 +5,7 @@ from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api import api
from controllers.service_api.app.error import (
@ -34,6 +34,7 @@ from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
raise InternalServerError()
class WorkflowRunByIdApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
"""
Run specific workflow by ID
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
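A hedged usage sketch for the new run-by-id route registered above; the base URL, API key, workflow id, and inputs are placeholders:

import requests

resp = requests.post(
    "https://api.dify.example/v1/workflows/00000000-0000-0000-0000-000000000000/run",
    headers={"Authorization": "Bearer app-xxxxxxxxxxxxxxxx"},  # placeholder Service API key
    json={
        "inputs": {"query": "hello"},   # required; may be an empty dict
        "response_mode": "blocking",    # or "streaming"
        "user": "end-user-1",           # required by the JSON user fetch above
    },
)
print(resp.status_code, resp.json())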

View File

@ -29,7 +29,7 @@ def _validate_name(name):
def _validate_description_length(description):
if len(description) > 400:
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource):
)
parser.add_argument(
"description",
type=str,
type=_validate_description_length,
nullable=True,
required=False,
default="",

View File

@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource):
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
)
# save file info
file = request.files["file"]
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
# save file info
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError
@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
class DocumentDetailApi(DatasetApiResource):
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
def get(self, tenant_id, dataset_id, document_id):
@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):
return response
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
api.add_resource(
DocumentAddByTextApi,
@ -588,7 +586,6 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")

View File

@ -12,18 +12,17 @@ from services.file_service import FileService
class FileApi(WebApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
file = request.files["file"]
source = request.form.get("source")
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
raise FilenameNotExistsError
source = request.form.get("source")
if source not in ("datasets", None):
source = None

View File

@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",

View File

@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
StreamEvent,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
@ -180,11 +181,15 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
default: Optional[Union[int, str, float, bool]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None

View File

@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return file.remote_url
return _to_url(file)
case FileAttribute.EXTENSION:
return file.extension
case FileAttribute.RELATED_ID:

View File

@ -121,9 +121,8 @@ class TokenBufferMemory:
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
pruned_memory.append(prompt_messages.pop(0))
prompt_messages.pop(0)
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages

View File

@ -4,6 +4,7 @@ import logging
import os
from datetime import datetime, timedelta
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
try:
# Choose the appropriate exporter based on config type
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
# Inspect the provided endpoint to determine its structure
parsed = urlparse(arize_phoenix_config.endpoint)
base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
if isinstance(arize_phoenix_config, ArizeConfig):
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
arize_endpoint = f"{base_endpoint}/v1"
arize_headers = {
"api_key": arize_phoenix_config.api_key or "",
"space_id": arize_phoenix_config.space_id or "",
@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
timeout=30,
)
else:
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
phoenix_headers = {
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",

View File

@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
return validate_url_with_path(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):

View File

@ -322,7 +322,7 @@ class OpsTraceManager:
:return:
"""
# auth check
if enabled == True:
if enabled:
try:
provider_config_map[tracing_provider]
except KeyError:
@ -407,7 +407,6 @@ class TraceTask:
def __init__(
self,
trace_type: Any,
trace_id: Optional[str] = None,
message_id: Optional[str] = None,
workflow_execution: Optional[WorkflowExecution] = None,
conversation_id: Optional[str] = None,
@ -423,7 +422,7 @@ class TraceTask:
self.timer = timer
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:

View File

@ -208,6 +208,7 @@ class BasePluginClient:
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
self._handle_plugin_daemon_error(error.error_type, error.message)
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
if rep.data is None:

View File

@ -2,6 +2,8 @@ from collections.abc import Mapping
from pydantic import TypeAdapter
from extensions.ext_logging import get_request_id
class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors."""
@ -11,7 +13,7 @@ class PluginDaemonError(Exception):
def __str__(self) -> str:
# returns the class name and description
return f"{self.__class__.__name__}: {self.description}"
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
class PluginDaemonInternalError(PluginDaemonError):

View File

@ -0,0 +1,190 @@
# Clickzetta Vector Database Integration
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
## Features
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
## Configuration
### Required Environment Variables
All seven configuration parameters are required:
```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password
# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```
### Optional Configuration
```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100
# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
```
## Usage
### 1. Set Clickzetta as the Vector Store
In your Dify configuration, set:
```bash
VECTOR_STORE=clickzetta
```
### 2. Table Structure
Clickzetta will automatically create tables with the following structure:
```sql
CREATE TABLE <collection_name> (
id STRING NOT NULL,
content STRING NOT NULL,
metadata JSON,
vector VECTOR(FLOAT, <dimension>) NOT NULL,
PRIMARY KEY (id)
);
-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
"distance.function" = "cosine_distance",
"scalar.type" = "f32"
);
-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
"analyzer" = "chinese",
"mode" = "smart"
);
```
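
For a quick sanity check outside of Dify's normal request flow, the sketch below exercises the `ClickzettaVector` class from this module directly. It is a minimal illustration, not a supported API: it assumes it runs inside the Dify `api` project (so the `core.*` imports resolve), that the seven connection parameters above are valid, and it uses a toy 4-dimensional embedding purely for demonstration.

```python
# Minimal sketch: using the Clickzetta vector store directly (illustration only).
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import (
    ClickzettaConfig,
    ClickzettaVector,
)
from core.rag.models.document import Document

config = ClickzettaConfig(
    username="your_username",
    password="your_password",
    instance="your_instance_id",
    service="api.clickzetta.com",
    workspace="your_workspace",
    vcluster="your_vcluster",
    schema_name="your_schema",
)

store = ClickzettaVector(collection_name="demo_collection", config=config)

docs = [
    Document(
        page_content="Dify integrates Clickzetta as a vector database.",
        metadata={"doc_id": "doc-1"},
    ),
]
embeddings = [[0.1, 0.2, 0.3, 0.4]]  # toy vectors; real embeddings have hundreds of dimensions

# First call creates the table plus vector/inverted indexes, then inserts the documents.
store.create(docs, embeddings)

# Similarity search; each returned Document carries a "score" in its metadata.
for doc in store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=5):
    print(doc.metadata.get("score"), doc.page_content)
```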
## Full-Text Search Capabilities
Clickzetta supports advanced full-text search with multiple analyzers:
### Analyzer Types
1. **keyword**: No tokenization, treats the entire string as a single token
- Best for: Exact matching, IDs, codes
2. **english**: Designed for English text
- Features: Recognizes ASCII letters and numbers, converts to lowercase
- Best for: English content
3. **chinese**: Chinese text tokenizer
- Features: Recognizes Chinese and English characters, removes punctuation
- Best for: Chinese or mixed Chinese-English content
4. **unicode**: Multi-language tokenizer based on Unicode
- Features: Recognizes text boundaries in multiple languages
- Best for: Multi-language content
### Analyzer Modes
- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)
### Full-Text Search Functions
- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
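
These functions are issued as ordinary SQL through the `clickzetta` Python connector (the same connector this integration uses). The sketch below is illustrative only: the table and column names follow the table structure shown above, and the connection parameters are placeholders.

```python
# Sketch: full-text search with MATCH_ALL through the clickzetta connector.
import clickzetta

conn = clickzetta.connect(
    username="your_username",
    password="your_password",
    instance="your_instance_id",
    service="api.clickzetta.com",
    workspace="your_workspace",
    vcluster="your_vcluster",
    schema="your_schema",
)

query = "vector database"
escaped = query.replace("'", "''")  # escape single quotes before inlining into SQL

with conn.cursor() as cursor:
    cursor.execute(
        f"""
        SELECT id, content
        FROM your_schema.your_collection
        WHERE MATCH_ALL(content, '{escaped}')
        LIMIT 10
        """
    )
    for row in cursor.fetchall():
        print(row)
```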
## Performance Optimization
### Vector Search
1. **Adjust the exploration factor** to trade accuracy against speed:
```sql
SET cz.vector.index.search.ef=64;
```
2. **Use appropriate distance functions**:
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
- `l2_distance`: Best for raw feature vectors
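
For intuition about how these distances relate to the relevance scores Dify attaches to retrieved documents, this integration converts a raw distance into a score roughly as follows (mirroring the conversion in this module's vector search):

```python
# How this module maps Clickzetta distance values to the "score" returned with each document.
def distance_to_score(distance: float, distance_function: str = "cosine_distance") -> float:
    if distance_function == "cosine_distance":
        # cosine_distance ranges from 0 (identical) to 2 (opposite); map to [0, 1]
        return 1 - (distance / 2)
    # l2_distance: smaller is better; squash into (0, 1]
    return 1 / (1 + distance)

print(distance_to_score(0.0))                  # 1.0  -> identical vectors
print(distance_to_score(0.4))                  # 0.8
print(distance_to_score(3.0, "l2_distance"))   # 0.25
```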
### Full-Text Search
1. **Choose the right analyzer**:
- Use `keyword` for exact matching
- Use language-specific analyzers for better tokenization
2. **Combine with vector search**:
- Pre-filter with full-text search for better performance
- Use hybrid search for improved relevance
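
A minimal hybrid-search sketch follows, combining a `MATCH_ANY` pre-filter with ordering by cosine distance. Names and the query vector are placeholders; the `CAST(... AS VECTOR(n))` literal form matches what this integration generates.

```python
# Sketch: hybrid retrieval - full-text pre-filter, then rank by vector distance.
import clickzetta

conn = clickzetta.connect(
    username="your_username",
    password="your_password",
    instance="your_instance_id",
    service="api.clickzetta.com",
    workspace="your_workspace",
    vcluster="your_vcluster",
    schema="your_schema",
)

query_text = "knowledge base"
query_vector = [0.1, 0.2, 0.3, 0.4]  # toy vector; use the embedding of query_text in practice
dimension = len(query_vector)

escaped = query_text.replace("'", "''")
vector_literal = "[" + ",".join(map(str, query_vector)) + "]"

hybrid_sql = f"""
    SELECT id, content,
           COSINE_DISTANCE(vector, CAST('{vector_literal}' AS VECTOR({dimension}))) AS distance
    FROM your_schema.your_collection
    WHERE MATCH_ANY(content, '{escaped}')
    ORDER BY distance
    LIMIT 10
"""

with conn.cursor() as cursor:
    cursor.execute(hybrid_sql)
    for row in cursor.fetchall():
        print(row)
```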
## Troubleshooting
### Connection Issues
1. Verify all 7 required configuration parameters are set
2. Check network connectivity to Clickzetta service
3. Ensure the user has proper permissions on the schema
### Search Performance
1. Verify vector index exists:
```sql
SHOW INDEX FROM <schema>.<table_name>;
```
2. Check if vector index is being used:
```sql
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
```
Look for `vector_index_search_type` in the execution plan.
### Full-Text Search Not Working
1. Verify inverted index is created
2. Check analyzer configuration matches your content language
3. Use `TOKENIZE()` function to test tokenization:
```sql
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
```
## Limitations
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables (the integration logs a warning and continues without the index; full-text search then falls back to LIKE-based matching)
4. Index naming constraints:
   - Index names must be unique within a schema
   - A column can have only one vector index at a time
   - The implementation derives index names from the table name, so they stay unique per table within the schema
## References
- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)

View File

@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify

View File

@ -0,0 +1,843 @@
import json
import logging
import queue
import threading
import uuid
from typing import TYPE_CHECKING, Any, Optional
import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from clickzetta import Connection
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset
logger = logging.getLogger(__name__)
# ClickZetta Lakehouse Vector Database Configuration
class ClickzettaConfig(BaseModel):
"""
Configuration class for Clickzetta connection.
"""
username: str
password: str
instance: str
service: str = "api.clickzetta.com"
workspace: str = "quick_start"
vcluster: str = "default_ap"
schema_name: str = "dify" # Renamed to avoid shadowing BaseModel.schema
# Advanced settings
batch_size: int = 20 # Reduced batch size to avoid large SQL statements
enable_inverted_index: bool = True # Enable inverted index for full-text search
analyzer_type: str = "chinese" # Analyzer type for full-text search: keyword, english, chinese, unicode
analyzer_mode: str = "smart" # Analyzer mode: max_word, smart
vector_distance_function: str = "cosine_distance" # l2_distance or cosine_distance
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
"""
if not values.get("username"):
raise ValueError("config CLICKZETTA_USERNAME is required")
if not values.get("password"):
raise ValueError("config CLICKZETTA_PASSWORD is required")
if not values.get("instance"):
raise ValueError("config CLICKZETTA_INSTANCE is required")
if not values.get("service"):
raise ValueError("config CLICKZETTA_SERVICE is required")
if not values.get("workspace"):
raise ValueError("config CLICKZETTA_WORKSPACE is required")
if not values.get("vcluster"):
raise ValueError("config CLICKZETTA_VCLUSTER is required")
if not values.get("schema_name"):
raise ValueError("config CLICKZETTA_SCHEMA is required")
return values
class ClickzettaVector(BaseVector):
"""
Clickzetta vector storage implementation.
"""
# Class-level write queue and lock for serializing writes
_write_queue: Optional[queue.Queue] = None
_write_thread: Optional[threading.Thread] = None
_write_lock = threading.Lock()
_shutdown = False
def __init__(self, collection_name: str, config: ClickzettaConfig):
super().__init__(collection_name)
self._config = config
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
self._connection: Optional[Connection] = None
self._init_connection()
self._init_write_queue()
def _init_connection(self):
"""Initialize Clickzetta connection."""
self._connection = clickzetta.connect(
username=self._config.username,
password=self._config.password,
instance=self._config.instance,
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name,
)
# Set session parameters for better string handling and performance optimization
if self._connection is not None:
with self._connection.cursor() as cursor:
# Use quote mode for string literal escaping to handle quotes better
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
logger.info("Set string literal escape mode to 'quote' for better quote handling")
# Performance optimization hints for vector operations
self._set_performance_hints(cursor)
def _set_performance_hints(self, cursor):
"""Set ClickZetta performance optimization hints for vector operations."""
try:
# Performance optimization hints for vector operations and query processing
performance_hints = [
# Vector index optimization
"SET cz.storage.parquet.vector.index.read.memory.cache = true",
"SET cz.storage.parquet.vector.index.read.local.cache = false",
# Query optimization
"SET cz.sql.table.scan.push.down.filter = true",
"SET cz.sql.table.scan.enable.ensure.filter = true",
"SET cz.storage.always.prefetch.internal = true",
"SET cz.optimizer.generate.columns.always.valid = true",
"SET cz.sql.index.prewhere.enabled = true",
# Storage optimization
"SET cz.storage.parquet.enable.io.prefetch = false",
"SET cz.optimizer.enable.mv.rewrite = false",
"SET cz.sql.dump.as.lz4 = true",
"SET cz.optimizer.limited.optimization.naive.query = true",
"SET cz.sql.table.scan.enable.push.down.log = false",
"SET cz.storage.use.file.format.local.stats = false",
"SET cz.storage.local.file.object.cache.level = all",
# Job execution optimization
"SET cz.sql.job.fast.mode = true",
"SET cz.storage.parquet.non.contiguous.read = true",
"SET cz.sql.compaction.after.commit = true",
]
for hint in performance_hints:
cursor.execute(hint)
logger.info(
"Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)
)
except Exception:
# Catch any errors setting performance hints but continue with defaults
logger.exception("Failed to set some performance hints, continuing with default settings")
@classmethod
def _init_write_queue(cls):
"""Initialize the write queue and worker thread."""
with cls._write_lock:
if cls._write_queue is None:
cls._write_queue = queue.Queue()
cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
cls._write_thread.start()
logger.info("Started Clickzetta write worker thread")
@classmethod
def _write_worker(cls):
"""Worker thread that processes write tasks sequentially."""
while not cls._shutdown:
try:
# Get task from queue with timeout
if cls._write_queue is not None:
task = cls._write_queue.get(timeout=1)
if task is None: # Shutdown signal
break
# Execute the write task
func, args, kwargs, result_queue = task
try:
result = func(*args, **kwargs)
result_queue.put((True, result))
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Write task failed")
result_queue.put((False, e))
finally:
cls._write_queue.task_done()
else:
break
except queue.Empty:
continue
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Write worker error")
def _execute_write(self, func, *args, **kwargs):
"""Execute a write operation through the queue."""
if ClickzettaVector._write_queue is None:
raise RuntimeError("Write queue not initialized")
result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
# Wait for result
success, result = result_queue.get()
if not success:
raise result
return result
def get_type(self) -> str:
"""Return the vector database type."""
return "clickzetta"
def _ensure_connection(self) -> "Connection":
"""Ensure connection is available and return it."""
if self._connection is None:
raise RuntimeError("Database connection not initialized")
return self._connection
def _table_exists(self) -> bool:
"""Check if the table exists."""
try:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
return True
except (RuntimeError, ValueError) as e:
if "table or view not found" in str(e).lower():
return False
else:
# Re-raise if it's a different error
raise
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create the collection and add initial documents."""
# Execute table creation through write queue to avoid concurrent conflicts
self._execute_write(self._create_table_and_indexes, embeddings)
# Add initial texts
if texts:
self.add_texts(texts, embeddings, **kwargs)
def _create_table_and_indexes(self, embeddings: list[list[float]]):
"""Create table and indexes (executed in write worker thread)."""
# Check if table already exists to avoid unnecessary index creation
if self._table_exists():
logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name)
return
# Create table with vector and metadata columns
dimension = len(embeddings[0]) if embeddings else 768
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
id STRING NOT NULL COMMENT 'Unique document identifier',
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
'High-dimensional embedding vector for semantic similarity search',
PRIMARY KEY (id)
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
"""
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(create_table_sql)
logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
# Create vector index
self._create_vector_index(cursor)
# Create inverted index for full-text search if enabled
if self._config.enable_inverted_index:
self._create_inverted_index(cursor)
def _create_vector_index(self, cursor):
"""Create HNSW vector index for similarity search."""
# Use a fixed index name based on table and column name
index_name = f"idx_{self._table_name}_vector"
# First check if an index already exists on this column
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
# Check if vector index already exists on the embedding column
if Field.VECTOR.value in str(idx).lower():
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
PROPERTIES (
"distance.function" = "{self._config.vector_distance_function}",
"scalar.type" = "f32",
"m" = "16",
"ef.construction" = "128"
)
"""
try:
cursor.execute(index_sql)
logger.info("Created vector index: %s", index_name)
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg:
logger.info("Vector index already exists: %s", e)
else:
logger.exception("Failed to create vector index")
raise
def _create_inverted_index(self, cursor):
"""Create inverted index for full-text search."""
# Use a fixed index name based on table name to avoid duplicates
index_name = f"idx_{self._table_name}_text"
# Check if an inverted index already exists on this column
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
idx_str = str(idx).lower()
# More precise check: look for inverted index specifically on the content column
if (
"inverted" in idx_str
and Field.CONTENT_KEY.value.lower() in idx_str
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
return
except (RuntimeError, ValueError) as e:
logger.warning("Failed to check existing indexes: %s", e)
index_sql = f"""
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
PROPERTIES (
"analyzer" = "{self._config.analyzer_type}",
"mode" = "{self._config.analyzer_mode}"
)
"""
try:
cursor.execute(index_sql)
logger.info("Created inverted index: %s", index_name)
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
# Handle ClickZetta specific error messages
if (
"already exists" in error_msg
or "already has index" in error_msg
or "with the same type" in error_msg
or "cannot create inverted index" in error_msg
) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
# Try to get the existing index name for logging
try:
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
existing_indexes = cursor.fetchall()
for idx in existing_indexes:
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
logger.info("Found existing inverted index: %s", idx)
break
except (RuntimeError, ValueError):
pass
else:
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add documents with embeddings to the collection."""
if not documents:
return
batch_size = self._config.batch_size
total_batches = (len(documents) + batch_size - 1) // batch_size
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
def _insert_batch(
self,
batch_docs: list[Document],
batch_embeddings: list[list[float]],
batch_index: int,
batch_size: int,
total_batches: int,
):
"""Insert a batch of documents using parameterized queries (executed in write worker thread)."""
if not batch_docs or not batch_embeddings:
logger.warning("Empty batch provided, skipping insertion")
return
if len(batch_docs) != len(batch_embeddings):
logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings))
return
# Prepare data for parameterized insertion
data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if doc.metadata else {}
if not isinstance(metadata, dict):
metadata = {}
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
# Fast path for JSON serialization
try:
metadata_json = json.dumps(metadata, ensure_ascii=True)
except (TypeError, ValueError):
logger.warning("JSON serialization failed, using empty dict")
metadata_json = "{}"
content = doc.page_content or ""
# According to ClickZetta docs, vector should be formatted as array string
# for external systems: '[1.0, 2.0, 3.0]'
vector_str = "[" + ",".join(map(str, embedding)) + "]"
data_rows.append([doc_id, content, metadata_json, vector_str])
# Check if we have any valid data to insert
if not data_rows:
logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches)
return
# Use parameterized INSERT with executemany for better performance and security
# Cast JSON and VECTOR in SQL, pass raw data as parameters
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
insert_sql = (
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
)
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
# Set session-level hints for batch insert operations
# Note: executemany doesn't support hints parameter, so we set them as session variables
cursor.execute("SET cz.sql.job.fast.mode = true")
cursor.execute("SET cz.sql.compaction.after.commit = true")
cursor.execute("SET cz.storage.always.prefetch.internal = true")
cursor.executemany(insert_sql, data_rows)
logger.info(
"Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)",
batch_index // batch_size + 1,
total_batches,
len(data_rows),
vector_dimension,
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
raise
def text_exists(self, id: str) -> bool:
"""Check if a document exists by ID."""
safe_id = self._safe_doc_id(id)
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id]
)
result = cursor.fetchone()
return result[0] > 0 if result else False
def delete_by_ids(self, ids: list[str]) -> None:
"""Delete documents by IDs."""
if not ids:
return
# Check if table exists before attempting delete
if not self._table_exists():
logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
return
# Execute delete through write queue
self._execute_write(self._delete_by_ids_impl, ids)
def _delete_by_ids_impl(self, ids: list[str]) -> None:
"""Implementation of delete by IDs (executed in write worker thread)."""
safe_ids = [self._safe_doc_id(id) for id in ids]
# Create properly escaped string literals for SQL
id_list = ",".join(f"'{id}'" for id in safe_ids)
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(sql)
def delete_by_metadata_field(self, key: str, value: str) -> None:
"""Delete documents by metadata field."""
# Check if table exists before attempting delete
if not self._table_exists():
logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
return
# Execute delete through write queue
self._execute_write(self._delete_by_metadata_field_impl, key, value)
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
"""Implementation of delete by metadata field (executed in write worker thread)."""
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
sql = (
f"DELETE FROM {self._config.schema_name}.{self._table_name} "
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
)
cursor.execute(sql, [value])
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
top_k = kwargs.get("top_k", 10)
score_threshold = kwargs.get("score_threshold", 0.0)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
if document_ids_filter:
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Add distance threshold based on distance function
vector_dimension = len(query_vector)
if self._config.vector_distance_function == "cosine_distance":
# For cosine distance, smaller is better (0 = identical, 2 = opposite)
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
)
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
# Execute vector search query
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
ORDER BY distance
LIMIT {top_k}
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Use hints parameter for vector search optimization
search_hints = {
"hints": {
"sdk.job.timeout": 60, # Increase timeout for vector search
"cz.sql.job.fast.mode": True,
"cz.storage.parquet.vector.index.read.memory.cache": True,
}
}
cursor.execute(search_sql, parameters=search_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add score based on distance
if self._config.vector_distance_function == "cosine_distance":
metadata["score"] = 1 - (row[3] / 2)
else:
metadata["score"] = 1 / (1 + row[3])
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents using full-text search with inverted index."""
if not self._config.enable_inverted_index:
logger.warning("Full-text search is not enabled. Enable inverted index in config.")
return []
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
if document_ids_filter:
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Use match_all function for full-text search
# match_all requires all terms to be present
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
where_clause = " AND ".join(filter_clauses)
# Execute full-text search query
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
try:
# Use hints parameter for full-text search optimization
fulltext_hints = {
"hints": {
"sdk.job.timeout": 30, # Timeout for full-text search
"cz.sql.job.fast.mode": True,
"cz.sql.index.prewhere.enabled": True,
}
}
cursor.execute(search_sql, parameters=fulltext_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
# Add a relevance score for full-text search
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs)
return documents
def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
"""Fallback search using LIKE operator."""
top_k = kwargs.get("top_k", 10)
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
if document_ids_filter:
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
# Use json_extract_string function for ClickZetta compatibility
filter_clauses.append(
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
)
# No need for dataset_id filter since each dataset has its own table
# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
where_clause = " AND ".join(filter_clauses)
search_sql = f"""
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
FROM {self._config.schema_name}.{self._table_name}
WHERE {where_clause}
LIMIT {top_k}
"""
documents = []
connection = self._ensure_connection()
with connection.cursor() as cursor:
# Use hints parameter for LIKE search optimization
like_hints = {
"hints": {
"sdk.job.timeout": 20, # Timeout for LIKE search
"cz.sql.job.fast.mode": True,
}
}
cursor.execute(search_sql, parameters=like_hints)
results = cursor.fetchall()
for row in results:
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
metadata = json.loads(row[2])
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
# Ensure required fields are set
metadata["doc_id"] = row[0] # segment id
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
if "document_id" not in metadata:
metadata["document_id"] = row[0] # fallback to segment id
metadata["score"] = 0.5 # Lower score for LIKE search
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
return documents
def delete(self) -> None:
"""Delete the entire collection."""
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
return ",".join(map(str, vector))
def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
if not doc_id:
return str(uuid.uuid4())
# Remove or replace potentially problematic characters
safe_id = str(doc_id)
# Only allow alphanumeric, hyphens, underscores
safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
if not safe_id: # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255] # Limit length
class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
"""Initialize a Clickzetta vector instance."""
# Get configuration from environment variables or dataset config
config = ClickzettaConfig(
username=dify_config.CLICKZETTA_USERNAME or "",
password=dify_config.CLICKZETTA_PASSWORD or "",
instance=dify_config.CLICKZETTA_INSTANCE or "",
service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
# Preserve an explicit False: `or True` would always re-enable the inverted index.
enable_inverted_index=(
    dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX
    if dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX is not None
    else True
),
analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",
analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart",
vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance",
)
# Use dataset collection name as table name
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
return ClickzettaVector(collection_name=collection_name, config=config)

View File

@ -7,6 +7,7 @@ from urllib.parse import urlparse
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.field import Field
@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:

View File

@ -1,5 +1,6 @@
import json
import logging
import math
from typing import Any, Optional
import tablestore # type: ignore
@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
access_key_secret: Optional[str] = None
instance_name: Optional[str] = None
endpoint: Optional[str] = None
normalize_full_text_bm25_score: Optional[bool] = False
@model_validator(mode="before")
@classmethod
@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
config.access_key_secret,
config.instance_name,
)
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]
return self._search_by_full_text(query, filtered_list, top_k)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
def delete(self) -> None:
self._delete_table_if_exist()
@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
@staticmethod
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
"""
Args:
score: BM25 search score.
k: decay factor; larger values of k make the curve steeper at the low-score end.
"""
normalized_score = 1 - math.exp(-k * score)
return max(0.0, min(1.0, normalized_score))
def _search_by_full_text(
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
documents = []
for search_hit in search_response.search_hits:
score = None
if self._normalize_full_text_bm25_score:
score = self._normalize_score_exp_decay(search_hit.score)
# when normalization is enabled, skip results whose normalized score does not exceed the threshold
if score and score <= score_threshold:
continue
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR.value)
vector = json.loads(vector_str) if vector_str else None
if score:
metadata["score"] = score
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
metadata=metadata,
)
)
if self._normalize_full_text_bm25_score:
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
),
)

View File

@ -246,6 +246,10 @@ class TencentVector(BaseVector):
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
if not self._enable_hybrid_search:
return []
res = self._client.hybrid_search(
@ -269,6 +273,7 @@ class TencentVector(BaseVector):
),
retrieve_vector=False,
limit=kwargs.get("top_k", 4),
filter=filter,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

View File

@ -172,6 +172,10 @@ class Vector:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@ -30,3 +30,4 @@ class VectorType(StrEnum):
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"

View File

@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",

View File

@ -1,5 +1,6 @@
import json
import logging
import operator
from typing import Any, Optional, cast
import requests
@ -130,13 +131,15 @@ class NotionExtractor(BaseExtractor):
data[property_name] = value
row_dict = {k: v for k, v in data.items() if v}
row_content = ""
for key, value in row_dict.items():
for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)):
if isinstance(value, dict):
value_dict = {k: v for k, v in value.items() if v}
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
row_content = row_content + f"{key}:{value_content}\n"
else:
row_content = row_content + f"{key}:{value}\n"
if "url" in result:
row_content = row_content + f"Row Page URL:{result.get('url', '')}\n"
database_content.append(row_content)
has_more = response_data.get("has_more", False)

View File

@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):
def extract(self) -> list[Document]:
"""Load given path as single page."""
content = self.parse_docx(self.file_path, "storage")
content = self.parse_docx(self.file_path)
return [
Document(
page_content=content,
@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
paragraph_content.append(run.text)
return "".join(paragraph_content).strip()
def _parse_paragraph(self, paragraph, image_map):
paragraph_content = []
for run in paragraph.runs:
if run.element.xpath(".//a:blip"):
for blip in run.element.xpath(".//a:blip"):
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
if embed_id:
rel_target = run.part.rels[embed_id].target_ref
if rel_target in image_map:
paragraph_content.append(image_map[rel_target])
if run.text.strip():
paragraph_content.append(run.text.strip())
return " ".join(paragraph_content) if paragraph_content else ""
def parse_docx(self, docx_path, image_folder):
def parse_docx(self, docx_path):
doc = DocxDocument(docx_path)
os.makedirs(image_folder, exist_ok=True)
content = []

View File

@ -5,14 +5,13 @@ from __future__ import annotations
from typing import Any, Optional
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
TokenTextSplitter,
Union,
)
@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_character_encoder, **kwargs)

View File

@ -20,9 +20,6 @@ class Tool(ABC):
The base class of a tool
"""
entity: ToolEntity
runtime: ToolRuntime
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
self.entity = entity
self.runtime = runtime

View File

@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
@staticmethod
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
try:
if local_tz is None:
local_tz = datetime.now().astimezone().tzinfo
if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
local_time = datetime.strptime(localtime, time_format)
localtime = local_tz.localize(local_time) # type: ignore
if local_tz is None:
localtime = local_time.astimezone() # type: ignore
elif isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
return timestamp
except Exception as e:

View File

@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool):
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time:
yield self.create_text_message(
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
)
return

View File

@ -20,8 +20,6 @@ class BuiltinTool(Tool):
:param meta: the meta data of a tool call processing
"""
provider: str
def __init__(self, provider: str, **kwargs):
super().__init__(**kwargs)
self.provider = provider

View File

@ -1,7 +1,8 @@
import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.parse import urlencode
import httpx
@ -20,10 +21,21 @@ API_TOOL_DEFAULT_TIMEOUT = (
)
class ApiTool(Tool):
api_bundle: ApiToolBundle
provider_id: str
@dataclass
class ParsedResponse:
"""Represents a parsed HTTP response with type information"""
content: Union[str, dict]
is_json: bool
def to_string(self) -> str:
"""Convert response to string format for credential validation"""
if isinstance(self.content, dict):
return json.dumps(self.content, ensure_ascii=False)
return str(self.content)
class ApiTool(Tool):
"""
Api tool
"""
@ -61,7 +73,9 @@ class ApiTool(Tool):
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response
return self.validate_and_parse_response(response)
parsed_response = self.validate_and_parse_response(response)
# For credential validation, always return as string
return parsed_response.to_string()
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.API
@ -115,23 +129,36 @@ class ApiTool(Tool):
return headers
def validate_and_parse_response(self, response: httpx.Response) -> str:
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
"""
validate the response
validate the response and return parsed content with type information
:return: ParsedResponse with content and is_json flag
"""
if isinstance(response, httpx.Response):
if response.status_code >= 400:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return "Empty response from the tool, please check your parameters and try again."
return ParsedResponse(
"Empty response from the tool, please check your parameters and try again.", False
)
# Check content type
content_type = response.headers.get("content-type", "").lower()
is_json_content_type = "application/json" in content_type
# Try to parse as JSON
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception:
return json.dumps(response)
json_response = response.json()
# If content-type indicates JSON, return as JSON object
if is_json_content_type:
return ParsedResponse(json_response, True)
else:
# If content-type doesn't indicate JSON, treat as text regardless of content
return ParsedResponse(response.text, False)
except Exception:
return response.text
# Not valid JSON, return as text
return ParsedResponse(response.text, False)
else:
raise ValueError(f"Invalid response type {type(response)}")
@ -372,7 +399,14 @@ class ApiTool(Tool):
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
# validate response
response = self.validate_and_parse_response(response)
parsed_response = self.validate_and_parse_response(response)
# assemble invoke message
yield self.create_text_message(response)
# assemble invoke message based on response type
if parsed_response.is_json and isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
else:
# Convert to string if needed and create text message
text_response = (
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
)
yield self.create_text_message(text_response)

View File

@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
class MCPTool(Tool):
tenant_id: str
icon: str
runtime_parameters: Optional[list[ToolParameter]]
server_url: str
provider_id: str
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.runtime_parameters = None
self.server_url = server_url
self.provider_id = provider_id

View File

@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class PluginTool(Tool):
tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[ToolParameter]]
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
@ -21,7 +16,7 @@ class PluginTool(Tool):
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None
self.runtime_parameters: Optional[list[ToolParameter]] = None
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.PLUGIN

View File

@ -29,7 +29,7 @@ from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatorUserRole
@ -247,7 +247,8 @@ class ToolEngine:
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
result += json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
ensure_ascii=False,
)
else:
result += str(response.message)

View File

@ -7,6 +7,7 @@ from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from yarl import URL
@ -616,7 +617,7 @@ class ToolManager:
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod

View File

@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
class DatasetRetrieverTool(Tool):
retrieval_tool: DatasetRetrieverBaseTool
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool

View File

@ -1,7 +1,14 @@
import logging
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional
from typing import Optional, cast
from uuid import UUID
import numpy as np
import pytz
from flask_login import current_user
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
if not tz_name:
tz_name = "UTC"
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
def safe_json_dict(d):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(
@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
)
else:
yield message
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
yield message
else:
yield message
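As a standalone illustration of what safe_json_value does to otherwise unserializable values (the current_user timezone lookup from the helper above is simplified to UTC here):

import json
from datetime import datetime, timezone
from decimal import Decimal
from uuid import UUID, uuid4

def to_json_safe(v):
    # Mirrors the conversions above: datetime -> ISO string, UUID -> str, Decimal -> float, bytes -> utf-8/hex.
    if isinstance(v, datetime):
        return v.astimezone(timezone.utc).isoformat()
    if isinstance(v, UUID):
        return str(v)
    if isinstance(v, Decimal):
        return float(v)
    if isinstance(v, bytes):
        try:
            return v.decode("utf-8")
        except UnicodeDecodeError:
            return v.hex()
    if isinstance(v, dict):
        return {k: to_json_safe(x) for k, x in v.items()}
    if isinstance(v, (list, tuple, set)):
        return [to_json_safe(x) for x in v]
    return v

payload = {"id": uuid4(), "price": Decimal("9.90"), "at": datetime.now(timezone.utc)}
print(json.dumps(to_json_safe(payload), ensure_ascii=False))  # serializes without TypeError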

View File

@ -25,15 +25,6 @@ logger = logging.getLogger(__name__)
class WorkflowTool(Tool):
workflow_app_id: str
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
workflow_as_tool_id: str
label: str
"""
Workflow tool.
"""

View File

@ -119,6 +119,13 @@ class ObjectSegment(Segment):
class ArraySegment(Segment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return super().text
@property
def markdown(self) -> str:
items = []
@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):
@property
def text(self) -> str:
# Return empty string for empty arrays instead of "[]"
if not self.value:
return ""
return json.dumps(self.value, ensure_ascii=False)
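A quick illustration of the behaviour change for empty array segments:

import json

empty: list[str] = []
old_text = json.dumps(empty, ensure_ascii=False)                       # "[]"
new_text = "" if not empty else json.dumps(empty, ensure_ascii=False)  # ""
print(repr(old_text), repr(new_text))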

View File

@ -109,7 +109,7 @@ class SegmentType(StrEnum):
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
# ARRAY_ANY does not have corresponding element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
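The is_valid fix above matters because the old generator yielded single-element lists, and a non-empty list is always truthy, so all() could never report a failing element:

results = [True, False, True]
print(all([r] for r in results))  # True  -- old, buggy form: tests the wrapper lists, not the values
print(all(r for r in results))    # False -- fixed form: tests the values themselves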

View File

@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
for i in range(1, len(raw_results)):
spk, txt = raw_results[i]
if spk == None:
if spk is None:
merged_results.append((None, current_text))
continue

View File

@ -91,7 +91,7 @@ class Executor:
self.auth = node_data.authorization
self.timeout = timeout
self.ssl_verify = node_data.ssl_verify
self.params = []
self.params = None
self.headers = {}
self.content = None
self.files = None
@ -139,7 +139,8 @@ class Executor:
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
)
self.params = result
if result:
self.params = result
def _init_headers(self):
"""
@ -277,6 +278,22 @@ class Executor:
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
# Handle Content-Type for multipart/form-data requests
# Fix for issue #22880: Missing boundary when using multipart/form-data
body = self.node_data.body
if body and body.type == "form-data":
# For multipart/form-data with files, let httpx handle the boundary automatically
# by not setting Content-Type header when files are present
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
# Only set Content-Type when there are no actual files
# This ensures httpx generates the correct boundary
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = "multipart/form-data"
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
# Set Content-Type for other body types
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
return headers
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
@ -384,15 +401,24 @@ class Executor:
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
# This prevents logging meaningless placeholder entries.
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
for key, (filename, content, mime_type) in self.files:
for file_entry in self.files:
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
continue # skip malformed entries
key = file_entry[0]
content = file_entry[1][1]
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
# fix: decode binary content
pass
# decode content safely
if isinstance(content, bytes):
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
elif isinstance(content, str):
body_string += content
else:
body_string += f"[Unsupported content type: {type(content).__name__}]"
body_string += "\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
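A small httpx sketch of why the boundary must come from the client when real files are present (the URL is a placeholder):

import httpx

files = {"file": ("report.txt", b"hello", "text/plain")}
req = httpx.Request("POST", "https://example.com/upload", files=files)
print(req.headers["content-type"])  # multipart/form-data; boundary=... (generated by httpx)

# Forcing the header yourself would drop the boundary the server needs to parse the body:
# headers = {"Content-Type": "multipart/form-data"}

When only the placeholder entry is present, setting the bare multipart/form-data header, as the hunk above does, keeps the declared type consistent with the body.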

View File

@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",

View File

@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode):
**{key: metadata_name, key_value: f"%{value}"}
)
)
case "in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
case "not in":
if isinstance(value, str):
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
escaped_value_str = ",".join(escaped_values)
else:
escaped_value_str = str(value)
filters.append(
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
**{key: metadata_name, key_value: escaped_value_str}
)
)
case "=" | "is":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
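Roughly, the new branches bind one comma-joined value and let Postgres split it server-side; an illustrative reduction of the escaping and clause shape (parameter names are placeholders):

value = "alice, bob, o'neil"
escaped = ",".join(v.strip().replace("'", "''") for v in value.split(","))
clause = "documents.doc_metadata ->> :key = any(string_to_array(:key_value, ','))"
print(escaped)  # alice,bob,o''neil  -- bound as :key_value
print(clause)   # the "in" clause; "not in" uses != all(string_to_array(...)) instead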

View File

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelFeature,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -1006,21 +1004,6 @@ class LLMNode(BaseNode):
)
return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
@staticmethod
def fetch_structured_output_schema(
*,

View File

@ -318,6 +318,33 @@ class ToolNode(BaseNode):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": message.message.text,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])

View File

@ -136,6 +136,8 @@ def init_app(app: DifyApp):
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
from opentelemetry.propagate import set_global_textmap
@ -234,6 +236,8 @@ def init_app(app: DifyApp):
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
instrument_exception_logging()
init_sqlalchemy_instrumentor(app)
RedisInstrumentor().instrument()
RequestsInstrumentor().instrument()
atexit.register(shutdown_tracer)
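For reference, the two added instrumentors can be exercised standalone like this, assuming the opentelemetry-instrumentation-redis and opentelemetry-instrumentation-requests packages (and the redis and requests libraries they wrap) are installed:

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor

trace.set_tracer_provider(TracerProvider())
RedisInstrumentor().instrument()     # emits a span for every Redis command
RequestsInstrumentor().instrument()  # emits a span for every requests.* HTTP call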

View File

@ -69,6 +69,19 @@ class Storage:
from extensions.storage.supabase_storage import SupabaseStorage
return SupabaseStorage
case StorageType.CLICKZETTA_VOLUME:
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
ClickZettaVolumeConfig,
ClickZettaVolumeStorage,
)
def create_clickzetta_volume_storage():
# ClickZettaVolumeConfig will automatically read from environment variables
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
volume_config = ClickZettaVolumeConfig()
return ClickZettaVolumeStorage(volume_config)
return create_clickzetta_volume_storage
case _:
raise ValueError(f"unsupported storage type {storage_type}")

View File

@ -0,0 +1,5 @@
"""ClickZetta Volume storage implementation."""
from .clickzetta_volume_storage import ClickZettaVolumeStorage
__all__ = ["ClickZettaVolumeStorage"]

View File

@ -0,0 +1,530 @@
"""ClickZetta Volume Storage Implementation
This module provides storage backend using ClickZetta Volume functionality.
Supports Table Volume, User Volume, and External Volume types.
"""
import logging
import os
import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Optional
import clickzetta # type: ignore[import]
from pydantic import BaseModel, model_validator
from extensions.storage.base_storage import BaseStorage
from .volume_permissions import VolumePermissionManager, check_volume_permission
logger = logging.getLogger(__name__)
class ClickZettaVolumeConfig(BaseModel):
"""Configuration for ClickZetta Volume storage."""
username: str = ""
password: str = ""
instance: str = ""
service: str = "api.clickzetta.com"
workspace: str = "quick_start"
vcluster: str = "default_ap"
schema_name: str = "dify"
volume_type: str = "table" # table|user|external
volume_name: Optional[str] = None # For external volumes
table_prefix: str = "dataset_" # Prefix for table volume names
dify_prefix: str = "dify_km" # Directory prefix for User Volume
permission_check: bool = True # Enable/disable permission checking
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
then fall back to CLICKZETTA_* environment variables (for vector DB config).
"""
import os
# Helper function to get environment variable with fallback
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
# First try CLICKZETTA_VOLUME_* specific config
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
if volume_value:
return str(volume_value)
# Then try environment variables
volume_env = os.getenv(volume_key)
if volume_env:
return volume_env
# Fall back to existing CLICKZETTA_* config
fallback_env = os.getenv(fallback_key)
if fallback_env:
return fallback_env
return default or ""
# Apply environment variables with fallback to existing CLICKZETTA_* config
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
values.setdefault(
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
)
values.setdefault(
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
)
values.setdefault(
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
)
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
# Volume-specific configurations (no fallback to vector DB config)
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
# Permission checking is disabled for now; set it directly to false
values.setdefault("permission_check", False)
# Validate required fields
if not values.get("username"):
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
if not values.get("password"):
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
if not values.get("instance"):
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
# Validate volume type
volume_type = values["volume_type"]
if volume_type not in ["table", "user", "external"]:
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
if volume_type == "external" and not values.get("volume_name"):
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
return values
class ClickZettaVolumeStorage(BaseStorage):
"""ClickZetta Volume storage implementation."""
def __init__(self, config: ClickZettaVolumeConfig):
"""Initialize ClickZetta Volume storage.
Args:
config: ClickZetta Volume configuration
"""
self._config = config
self._connection = None
self._permission_manager: VolumePermissionManager | None = None
self._init_connection()
self._init_permission_manager()
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
def _init_connection(self):
"""Initialize ClickZetta connection."""
try:
self._connection = clickzetta.connect(
username=self._config.username,
password=self._config.password,
instance=self._config.instance,
service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
schema=self._config.schema_name,
)
logger.debug("ClickZetta connection established")
except Exception as e:
logger.exception("Failed to connect to ClickZetta")
raise
def _init_permission_manager(self):
"""Initialize permission manager."""
try:
self._permission_manager = VolumePermissionManager(
self._connection, self._config.volume_type, self._config.volume_name
)
logger.debug("Permission manager initialized")
except Exception as e:
logger.exception("Failed to initialize permission manager")
raise
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
"""Get the appropriate volume path based on volume type."""
if self._config.volume_type == "user":
# Add dify prefix for User Volume to organize files
return f"{self._config.dify_prefix}/{filename}"
elif self._config.volume_type == "table":
# Check if this should use User Volume (special directories)
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
# Use User Volume with dify prefix for special directories
return f"{self._config.dify_prefix}/{filename}"
if dataset_id:
return f"{self._config.table_prefix}{dataset_id}/{filename}"
else:
# Extract dataset_id from filename if not provided
# Format: dataset_id/filename
if "/" in filename:
return filename
else:
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
elif self._config.volume_type == "external":
return filename
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
"""Get SQL prefix for volume operations."""
if self._config.volume_type == "user":
return "USER VOLUME"
elif self._config.volume_type == "table":
# For Dify's current file storage pattern, most files are stored in
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
# These should use USER VOLUME for better compatibility
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return "USER VOLUME"
# Only use TABLE VOLUME for actual dataset-specific paths
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
if dataset_id:
table_name = f"{self._config.table_prefix}{dataset_id}"
else:
# Default table name for generic operations
table_name = "default_dataset"
return f"TABLE VOLUME {table_name}"
elif self._config.volume_type == "external":
return f"VOLUME {self._config.volume_name}"
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
def _execute_sql(self, sql: str, fetch: bool = False):
"""Execute SQL command."""
try:
if self._connection is None:
raise RuntimeError("Connection not initialized")
with self._connection.cursor() as cursor:
cursor.execute(sql)
if fetch:
return cursor.fetchall()
return None
except Exception as e:
logger.exception("SQL execution failed: %s", sql)
raise
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
"""Ensure table volume exists for the given dataset_id."""
if self._config.volume_type != "table" or not dataset_id:
return
# Skip for upload_files and other special directories that use USER VOLUME
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
return
table_name = f"{self._config.table_prefix}{dataset_id}"
try:
# Check if table exists
check_sql = f"SHOW TABLES LIKE '{table_name}'"
result = self._execute_sql(check_sql, fetch=True)
if not result:
# Create table with volume
create_sql = f"""
CREATE TABLE {table_name} (
id INT PRIMARY KEY AUTO_INCREMENT,
filename VARCHAR(255) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_filename (filename)
) WITH VOLUME
"""
self._execute_sql(create_sql)
logger.info("Created table volume: %s", table_name)
except Exception as e:
logger.warning("Failed to create table volume %s: %s", table_name, e)
# Don't raise exception, let the operation continue
# The table might exist but not be visible due to permissions
def save(self, filename: str, data: bytes) -> None:
"""Save data to ClickZetta Volume.
Args:
filename: File path in volume
data: File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Ensure table volume exists (for table volumes)
if dataset_id:
self._ensure_table_volume_exists(dataset_id)
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "save", dataset_id)
# Write data to temporary file
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_file.write(data)
temp_file_path = temp_file.name
try:
# Upload to volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
else:
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
finally:
# Clean up temporary file
Path(temp_file_path).unlink(missing_ok=True)
def load_once(self, filename: str) -> bytes:
"""Load file content from ClickZetta Volume.
Args:
filename: File path in volume
Returns:
File content as bytes
"""
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
# Check permissions (if enabled)
if self._config.permission_check:
# Skip permission check for special directories that use USER VOLUME
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
if self._permission_manager is not None:
check_volume_permission(self._permission_manager, "load_once", dataset_id)
# Download to temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
else:
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
self._execute_sql(sql)
# Find the downloaded file (may be in subdirectories)
downloaded_file = None
for root, dirs, files in os.walk(temp_dir):
for file in files:
if file == filename or file == os.path.basename(filename):
downloaded_file = Path(root) / file
break
if downloaded_file:
break
if not downloaded_file or not downloaded_file.exists():
raise FileNotFoundError(f"Downloaded file not found: {filename}")
content = downloaded_file.read_bytes()
logger.debug("File %s loaded from ClickZetta Volume", filename)
return content
def load_stream(self, filename: str) -> Generator:
"""Load file as stream from ClickZetta Volume.
Args:
filename: File path in volume
Yields:
File content chunks
"""
content = self.load_once(filename)
batch_size = 4096
stream = BytesIO(content)
while chunk := stream.read(batch_size):
yield chunk
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
def download(self, filename: str, target_filepath: str):
"""Download file from ClickZetta Volume to local path.
Args:
filename: File path in volume
target_filepath: Local target file path
"""
content = self.load_once(filename)
with Path(target_filepath).open("wb") as f:
f.write(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
def exists(self, filename: str) -> bool:
"""Check if file exists in ClickZetta Volume.
Args:
filename: File path in volume
Returns:
True if file exists, False otherwise
"""
try:
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
else:
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
rows = self._execute_sql(sql, fetch=True)
exists = len(rows) > 0
logger.debug("File %s exists check: %s", filename, exists)
return exists
except Exception as e:
logger.warning("Error checking file existence for %s: %s", filename, e)
return False
def delete(self, filename: str):
"""Delete file from ClickZetta Volume.
Args:
filename: File path in volume
"""
if not self.exists(filename):
logger.debug("File %s not found, skip delete", filename)
return
# Extract dataset_id from filename if present
dataset_id = None
if "/" in filename and self._config.volume_type == "table":
parts = filename.split("/", 1)
if parts[0].startswith(self._config.table_prefix):
dataset_id = parts[0][len(self._config.table_prefix) :]
filename = parts[1]
else:
dataset_id = parts[0]
filename = parts[1]
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# Get the actual volume path (may include dify_km prefix)
volume_path = self._get_volume_path(filename, dataset_id)
# For User Volume, use the full path with dify_km prefix
if volume_prefix == "USER VOLUME":
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
else:
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
self._execute_sql(sql)
logger.debug("File %s deleted from ClickZetta Volume", filename)
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""Scan files and directories in ClickZetta Volume.
Args:
path: Path to scan (dataset_id for table volumes)
files: Include files in results
directories: Include directories in results
Returns:
List of file/directory paths
"""
try:
# For table volumes, path is treated as dataset_id
dataset_id = None
if self._config.volume_type == "table":
dataset_id = path
path = "" # Root of the table volume
volume_prefix = self._get_volume_sql_prefix(dataset_id)
# For User Volume, add dify prefix to path
if volume_prefix == "USER VOLUME":
if path:
scan_path = f"{self._config.dify_prefix}/{path}"
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
else:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
else:
if path:
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
else:
sql = f"LIST {volume_prefix}"
rows = self._execute_sql(sql, fetch=True)
result = []
for row in rows:
file_path = row[0] # relative_path column
# For User Volume, remove dify prefix from results
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
result.append(file_path)
logger.debug("Scanned %d items in path %s", len(result), path)
return result
except Exception as e:
logger.exception("Error scanning path %s", path)
return []
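A hypothetical end-to-end usage sketch; the credentials are placeholders and running it requires a reachable ClickZetta instance:

import os

os.environ.setdefault("CLICKZETTA_VOLUME_USERNAME", "user")
os.environ.setdefault("CLICKZETTA_VOLUME_PASSWORD", "secret")
os.environ.setdefault("CLICKZETTA_VOLUME_INSTANCE", "my-instance")
os.environ.setdefault("CLICKZETTA_VOLUME_TYPE", "user")  # table | user | external

from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
    ClickZettaVolumeConfig,
    ClickZettaVolumeStorage,
)

storage = ClickZettaVolumeStorage(ClickZettaVolumeConfig())  # opens the ClickZetta connection
storage.save("upload_files/tenant-1/doc.txt", b"hello")      # PUT ... TO USER VOLUME FILE 'dify_km/...'
print(storage.exists("upload_files/tenant-1/doc.txt"))
print(storage.load_once("upload_files/tenant-1/doc.txt"))
storage.delete("upload_files/tenant-1/doc.txt")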

View File

@ -0,0 +1,516 @@
"""ClickZetta Volume文件生命周期管理
该模块提供文件版本控制自动清理备份和恢复等生命周期管理功能
支持知识库文件的完整生命周期管理
"""
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Optional
logger = logging.getLogger(__name__)
class FileStatus(Enum):
"""文件状态枚举"""
ACTIVE = "active" # 活跃状态
ARCHIVED = "archived" # 已归档
DELETED = "deleted" # 已删除(软删除)
BACKUP = "backup" # 备份文件
@dataclass
class FileMetadata:
"""文件元数据"""
filename: str
size: int | None
created_at: datetime
modified_at: datetime
version: int | None
status: FileStatus
checksum: Optional[str] = None
tags: Optional[dict[str, str]] = None
parent_version: Optional[int] = None
def to_dict(self) -> dict:
"""转换为字典格式"""
data = asdict(self)
data["created_at"] = self.created_at.isoformat()
data["modified_at"] = self.modified_at.isoformat()
data["status"] = self.status.value
return data
@classmethod
def from_dict(cls, data: dict) -> "FileMetadata":
"""从字典创建实例"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
data["status"] = FileStatus(data["status"])
return cls(**data)
class FileLifecycleManager:
"""文件生命周期管理器"""
def __init__(self, storage, dataset_id: Optional[str] = None):
"""初始化生命周期管理器
Args:
storage: ClickZetta Volume存储实例
dataset_id: 数据集ID用于Table Volume
"""
self._storage = storage
self._dataset_id = dataset_id
self._metadata_file = ".dify_file_metadata.json"
self._version_prefix = ".versions/"
self._backup_prefix = ".backups/"
self._deleted_prefix = ".deleted/"
# Get the permission manager (if any)
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
"""保存文件并管理生命周期
Args:
filename: 文件名
data: 文件内容
tags: 文件标签
Returns:
文件元数据
"""
# Permission check
if not self._check_permission(filename, "save"):
from .volume_permissions import VolumePermissionError
raise VolumePermissionError(
f"Permission denied for lifecycle save operation on file: {filename}",
operation="save",
volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
dataset_id=self._dataset_id,
)
try:
# 1. Check whether an older version exists
metadata_dict = self._load_metadata()
current_metadata = metadata_dict.get(filename)
# 2. If an older version exists, create a version backup
if current_metadata:
self._create_version_backup(filename, current_metadata)
# 3. Compute file information
now = datetime.now()
checksum = self._calculate_checksum(data)
new_version = (current_metadata["version"] + 1) if current_metadata else 1
# 4. Save the new file
self._storage.save(filename, data)
# 5. Create metadata
created_at = now
parent_version = None
if current_metadata:
# If created_at is a string, convert it to datetime
if isinstance(current_metadata["created_at"], str):
created_at = datetime.fromisoformat(current_metadata["created_at"])
else:
created_at = current_metadata["created_at"]
parent_version = current_metadata["version"]
file_metadata = FileMetadata(
filename=filename,
size=len(data),
created_at=created_at,
modified_at=now,
version=new_version,
status=FileStatus.ACTIVE,
checksum=checksum,
tags=tags or {},
parent_version=parent_version,
)
# 6. Update metadata
metadata_dict[filename] = file_metadata.to_dict()
self._save_metadata(metadata_dict)
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
return file_metadata
except Exception as e:
logger.exception("Failed to save file with lifecycle")
raise
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
"""获取文件元数据
Args:
filename: 文件名
Returns:
文件元数据如果不存在返回None
"""
try:
metadata_dict = self._load_metadata()
if filename in metadata_dict:
return FileMetadata.from_dict(metadata_dict[filename])
return None
except Exception as e:
logger.exception("Failed to get file metadata for %s", filename)
return None
def list_file_versions(self, filename: str) -> list[FileMetadata]:
"""列出文件的所有版本
Args:
filename: 文件名
Returns:
文件版本列表按版本号排序
"""
try:
versions = []
# Get the current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
versions.append(current_metadata)
# Get historical versions
version_pattern = f"{self._version_prefix}{filename}.v*"
try:
version_files = self._storage.scan(self._dataset_id or "", files=True)
for file_path in version_files:
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
# Parse the version number
version_str = file_path.split(".v")[-1].split(".")[0]
try:
version_num = int(version_str)
# Simplified handling here; metadata should really be read from the version file
# For now only basic metadata information is created
except ValueError:
continue
except:
# If version files cannot be scanned, return only the current version
pass
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
except Exception as e:
logger.exception("Failed to list file versions for %s", filename)
return []
def restore_version(self, filename: str, version: int) -> bool:
"""恢复文件到指定版本
Args:
filename: 文件名
version: 要恢复的版本号
Returns:
恢复是否成功
"""
try:
version_filename = f"{self._version_prefix}{filename}.v{version}"
# Check whether the version file exists
if not self._storage.exists(version_filename):
logger.warning("Version %s of %s not found", version, filename)
return False
# Read the version file content
version_data = self._storage.load_once(version_filename)
# Back up the current version
current_metadata = self.get_file_metadata(filename)
if current_metadata:
self._create_version_backup(filename, current_metadata.to_dict())
# Restore the file
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
return True
except Exception as e:
logger.exception("Failed to restore %s to version %s", filename, version)
return False
def archive_file(self, filename: str) -> bool:
"""归档文件
Args:
filename: 文件名
Returns:
归档是否成功
"""
# Permission check
if not self._check_permission(filename, "archive"):
logger.warning("Permission denied for archive operation on file: %s", filename)
return False
try:
# Update the file status to archived
metadata_dict = self._load_metadata()
if filename not in metadata_dict:
logger.warning("File %s not found in metadata", filename)
return False
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s archived successfully", filename)
return True
except Exception as e:
logger.exception("Failed to archive file %s", filename)
return False
def soft_delete_file(self, filename: str) -> bool:
"""软删除文件(移动到删除目录)
Args:
filename: 文件名
Returns:
删除是否成功
"""
# Permission check
if not self._check_permission(filename, "delete"):
logger.warning("Permission denied for soft delete operation on file: %s", filename)
return False
try:
# Check whether the file exists
if not self._storage.exists(filename):
logger.warning("File %s not found", filename)
return False
# Read the file content
file_data = self._storage.load_once(filename)
# Move it to the deleted directory
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self._storage.save(deleted_filename, file_data)
# Delete the original file
self._storage.delete(filename)
# Update the metadata
metadata_dict = self._load_metadata()
if filename in metadata_dict:
metadata_dict[filename]["status"] = FileStatus.DELETED.value
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
self._save_metadata(metadata_dict)
logger.info("File %s soft deleted successfully", filename)
return True
except Exception as e:
logger.exception("Failed to soft delete file %s", filename)
return False
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
"""清理旧版本文件
Args:
max_versions: 保留的最大版本数
max_age_days: 版本文件的最大保留天数
Returns:
清理的文件数量
"""
try:
cleaned_count = 0
cutoff_date = datetime.now() - timedelta(days=max_age_days)
# Get all version files
try:
all_files = self._storage.scan(self._dataset_id or "", files=True)
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
# Group versions by file
file_versions: dict[str, list[tuple[int, str]]] = {}
for version_file in version_files:
# Parse the file name and version
parts = version_file[len(self._version_prefix) :].split(".v")
if len(parts) >= 2:
base_filename = parts[0]
version_part = parts[1].split(".")[0]
try:
version_num = int(version_part)
if base_filename not in file_versions:
file_versions[base_filename] = []
file_versions[base_filename].append((version_num, version_file))
except ValueError:
continue
# Clean up old versions for each file
for base_filename, versions in file_versions.items():
# Sort by version number
versions.sort(key=lambda x: x[0], reverse=True)
# Keep the newest max_versions versions and delete the rest
if len(versions) > max_versions:
to_delete = versions[max_versions:]
for version_num, version_file in to_delete:
self._storage.delete(version_file)
cleaned_count += 1
logger.debug("Cleaned old version: %s", version_file)
logger.info("Cleaned %d old version files", cleaned_count)
except Exception as e:
logger.warning("Could not scan for version files: %s", e)
return cleaned_count
except Exception as e:
logger.exception("Failed to cleanup old versions")
return 0
def get_storage_statistics(self) -> dict[str, Any]:
"""获取存储统计信息
Returns:
存储统计字典
"""
try:
metadata_dict = self._load_metadata()
stats: dict[str, Any] = {
"total_files": len(metadata_dict),
"active_files": 0,
"archived_files": 0,
"deleted_files": 0,
"total_size": 0,
"versions_count": 0,
"oldest_file": None,
"newest_file": None,
}
oldest_date = None
newest_date = None
for filename, metadata in metadata_dict.items():
file_meta = FileMetadata.from_dict(metadata)
# Count file statuses
if file_meta.status == FileStatus.ACTIVE:
stats["active_files"] = (stats["active_files"] or 0) + 1
elif file_meta.status == FileStatus.ARCHIVED:
stats["archived_files"] = (stats["archived_files"] or 0) + 1
elif file_meta.status == FileStatus.DELETED:
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
# Accumulate total size
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
# Accumulate version counts
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
# Find the newest and oldest files
if oldest_date is None or file_meta.created_at < oldest_date:
oldest_date = file_meta.created_at
stats["oldest_file"] = filename
if newest_date is None or file_meta.modified_at > newest_date:
newest_date = file_meta.modified_at
stats["newest_file"] = filename
return stats
except Exception as e:
logger.exception("Failed to get storage statistics")
return {}
def _create_version_backup(self, filename: str, metadata: dict):
"""创建版本备份"""
try:
# Read the current file content
current_data = self._storage.load_once(filename)
# Save it as a version file
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
self._storage.save(version_filename, current_data)
logger.debug("Created version backup: %s", version_filename)
except Exception as e:
logger.warning("Failed to create version backup for %s: %s", filename, e)
def _load_metadata(self) -> dict[str, Any]:
"""加载元数据文件"""
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
result = json.loads(metadata_content.decode("utf-8"))
return dict(result) if result else {}
else:
return {}
except Exception as e:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
"""保存元数据文件"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
logger.debug("Metadata saved successfully")
except Exception as e:
logger.exception("Failed to save metadata")
raise
def _calculate_checksum(self, data: bytes) -> str:
"""计算文件校验和"""
import hashlib
return hashlib.md5(data).hexdigest()
def _check_permission(self, filename: str, operation: str) -> bool:
"""检查文件操作权限
Args:
filename: 文件名
operation: 操作类型
Returns:
True if permission granted, False otherwise
"""
# If there is no permission manager, allow by default
if not self._permission_manager:
return True
try:
# Map the operation type to a permission
operation_mapping = {
"save": "save",
"load": "load_once",
"delete": "delete",
"archive": "delete", # 归档需要删除权限
"restore": "save", # 恢复需要写权限
"cleanup": "delete", # 清理需要删除权限
"read": "load_once",
"write": "save",
}
mapped_operation = operation_mapping.get(operation, operation)
# Check the permission
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
return bool(result)
except Exception as e:
logger.exception("Permission check failed for %s operation %s", filename, operation)
# Secure default: deny access when the permission check fails
return False
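Since the manager only relies on save / load_once / exists / delete / scan from the underlying storage, a hedged in-memory sketch is enough to see the versioning flow; the import of FileLifecycleManager from the module above is omitted here because the file path is not shown in this view:

class MemoryStorage:
    def __init__(self):
        self.files: dict[str, bytes] = {}
    def save(self, name, data): self.files[name] = data
    def load_once(self, name): return self.files[name]
    def exists(self, name): return name in self.files
    def delete(self, name): self.files.pop(name, None)
    def scan(self, path, files=True, directories=False): return list(self.files)

manager = FileLifecycleManager(MemoryStorage())
meta_v1 = manager.save_with_lifecycle("doc.txt", b"v1", tags={"source": "demo"})
meta_v2 = manager.save_with_lifecycle("doc.txt", b"v2")
print(meta_v2.version, meta_v2.parent_version)          # 2 1
print(manager.get_storage_statistics()["total_files"])  # 1 (plus a version backup under .versions/)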

Some files were not shown because too many files have changed in this diff Show More