mirror of https://github.com/langgenius/dify.git

commit 985becbc41
Merge branch 'main' into feat/memory-orchestration-fed
@@ -0,0 +1,44 @@
+name: "✨ Refactor"
+description: Refactor existing code for improved readability and maintainability.
+title: "[Chore/Refactor] "
+labels:
+  - refactor
+body:
+  - type: checkboxes
+    attributes:
+      label: Self Checks
+      description: "To make sure we get to you in time, please check the following :)"
+      options:
+        - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
+          required: true
+        - label: This is only for refactoring; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
+          required: true
+        - label: I have [searched for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
+          required: true
+        - label: I confirm that I am using English to submit this report, otherwise it will be closed.
+          required: true
+        - label: "【中文用户 & Non English User】Please use English to submit, otherwise the issue will be closed :)"
+          required: true
+        - label: "Please do not modify this template :) and fill in all the required fields."
+          required: true
+  - type: textarea
+    id: description
+    attributes:
+      label: Description
+      placeholder: "Describe the refactor you are proposing."
+    validations:
+      required: true
+  - type: textarea
+    id: motivation
+    attributes:
+      label: Motivation
+      placeholder: "Explain why this refactor is necessary."
+    validations:
+      required: false
+  - type: textarea
+    id: additional-context
+    attributes:
+      label: Additional Context
+      placeholder: "Add any other context or screenshots about the request here."
+    validations:
+      required: false
@@ -99,3 +99,6 @@ jobs:
      - name: Run Tool
        run: uv run --project api bash dev/pytest/pytest_tools.sh
+
+      - name: Run TestContainers
+        run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
@@ -9,6 +9,7 @@ permissions:

jobs:
  autofix:
+    if: github.repository == 'langgenius/dify'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
@@ -7,6 +7,7 @@ on:
      - "deploy/dev"
      - "deploy/enterprise"
      - "build/**"
      - "release/e-*"
    tags:
      - "*"
@@ -5,6 +5,10 @@ on:
    types: [closed]
    branches: [main]

+permissions:
+  contents: write
+  pull-requests: write
+
jobs:
  check-and-update:
    if: github.event.pull_request.merged == true
@@ -16,7 +20,7 @@ jobs:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 2 # last 2 commits
-          persist-credentials: false
+          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Check for file changes in i18n/en-US
        id: check_files
@@ -49,7 +53,7 @@ jobs:
        if: env.FILES_CHANGED == 'true'
        run: pnpm install --frozen-lockfile

-      - name: Run npm script
+      - name: Generate i18n translations
        if: env.FILES_CHANGED == 'true'
        run: pnpm run auto-gen-i18n
@@ -57,6 +61,7 @@ jobs:
        if: env.FILES_CHANGED == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
          commit-message: Update i18n files based on en-US changes
          title: 'chore: translate i18n files'
          body: This PR was automatically created to update i18n files based on changes in en-US locale.
@@ -215,3 +215,4 @@ mise.toml
# AI Assistant
.roo/
api/.env.backup
+/clickzetta
@@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/

One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Deploy to AKS with Azure Devops Pipeline
+
+One-Click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -217,6 +217,10 @@ docker compose up -d

Deploy Dify to Alibaba Cloud in one click with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -235,6 +235,10 @@ Star Dify on GitHub

[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -233,6 +233,9 @@ docker compose up -d

Deploy Dify to Alibaba Cloud in one click with [Alibaba Cloud Data Management (DMS)](https://help.aliyun.com/zh/dms/dify-in-invitational-preview)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)

## Star History
@@ -230,6 +230,10 @@ Deploy Dify on AWS with [CDK](https://aws.amazon.com/cdk/)

One-click deployment of Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline for AKS deployment
+
+Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -230,6 +230,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

Deploy Dify to Alibaba Cloud with a single click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -228,6 +228,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

Deploy Dify to Alibaba Cloud in one click with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -227,6 +227,10 @@ docker compose up -d
#### Alibaba Cloud Data Management

Deploy Dify to Alibaba Cloud in one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -228,6 +228,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -222,6 +222,10 @@ Deploy Dify to Kubernetes with premium scaling presets configured

Deploy Dify to Alibaba Cloud in one click via [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -227,6 +227,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -228,6 +228,10 @@ Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)

Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -221,6 +221,10 @@ Deploy Dify to the cloud platform with one click using [terraform](https://www.ter

Deploy Dify to Alibaba Cloud in one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -233,6 +233,10 @@ All of Dify's features come with corresponding APIs, so you can easily integrate Dify

Deploy Dify to Alibaba Cloud in one click via [Alibaba Cloud Data Management (DMS)](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -224,6 +224,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)

Deploy Dify to Alibaba Cloud with just one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)

+#### Using Azure Devops Pipeline to deploy to AKS
+
+Deploy Dify to AKS with just one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
+
## Contributing
@@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
TABLESTORE_INSTANCE_NAME=instance-name
TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
+TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false

# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
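The added `TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE` flag asks the Tablestore integration to map raw BM25 full-text scores into [0, 1] so they can be combined with vector similarity scores on a common scale. The exact mapping lives inside the Tablestore vector store; the sketch below is a generic min-max normalization, offered only to show the idea.

```python
def normalize_bm25_scores(scores: list[float]) -> list[float]:
    """Illustrative min-max normalization of raw BM25 scores to [0, 1].

    BM25 scores are unbounded, while cosine similarities already sit in
    [0, 1]; normalizing puts both on a comparable scale for hybrid search.
    The actual formula used by the Tablestore integration may differ.
    """
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]


assert normalize_bm25_scores([2.0, 4.0, 6.0]) == [0.0, 0.5, 1.0]
```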
@@ -19,7 +19,7 @@ RUN apt-get update \

# Install Python dependencies
COPY pyproject.toml uv.lock ./
-RUN uv sync --locked
+RUN uv sync --locked --no-dev

# production stage
FROM base AS production
@@ -5,10 +5,11 @@ import secrets
from typing import Any, Optional

import click
+import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
-from werkzeug.exceptions import NotFound
+from sqlalchemy.exc import SQLAlchemyError

from configs import dify_config
from constants.languages import languages
@@ -180,8 +181,8 @@ def migrate_annotation_vector_database():
            )
            if not apps:
                break
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise

        page += 1
        for app in apps:
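The migration loops in this file replace their `except NotFound` guard with `except SQLAlchemyError: raise`. With `error_out=False`, `db.paginate` returns an empty page past the end of the result set instead of raising werkzeug's NotFound, so the old guard was effectively unreachable, and genuine database errors are now re-raised rather than silently ending the loop. A runnable toy version of the loop shape (the `paginate` callable stands in for `db.paginate`):

```python
from sqlalchemy.exc import SQLAlchemyError


def paginate_all(paginate, process, per_page: int = 50) -> int:
    """Drain a Flask-SQLAlchemy style paginator page by page.

    `paginate(page, per_page)` stands in for db.paginate(..., error_out=False),
    which yields an empty page past the end instead of raising NotFound.
    """
    page, seen = 1, 0
    while True:
        try:
            items = paginate(page, per_page)
        except SQLAlchemyError:
            raise  # surface genuine database failures to the caller
        if not items:
            return seen
        for item in items:
            process(item)
            seen += 1
        page += 1


# Toy stand-in: three pages of data, then empty pages forever.
data = [list(range(5)), list(range(5)), [2]]
assert paginate_all(lambda p, n: data[p - 1] if p <= len(data) else [], print) == 11
```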
@@ -307,8 +308,8 @@ def migrate_knowledge_vector_database():
            )

            datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise

        page += 1
        for dataset in datasets:
@@ -457,7 +458,7 @@ def convert_to_agent_apps():
    """

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query))
+        rs = conn.execute(sa.text(sql_query))

    apps = []
    for i in rs:
@@ -560,8 +561,8 @@ def old_metadata_migration():
                .order_by(DatasetDocument.created_at.desc())
            )
            documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
-        except NotFound:
-            break
+        except SQLAlchemyError:
+            raise
        if not documents:
            break
        for document in documents:
@@ -702,7 +703,7 @@ def fix_app_site_missing():
    sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
    where sites.id is null limit 1000"""
    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql))
+        rs = conn.execute(sa.text(sql))

    processed_count = 0
    for i in rs:
@@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool):
    )
    orphaned_message_files = []
    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(query))
+        rs = conn.execute(sa.text(query))
        for i in rs:
            orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
@@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool):
        click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
        query = "DELETE FROM message_files WHERE id IN :ids"
        with db.engine.begin() as conn:
-            conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
+            conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
        click.echo(
            click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
        )
@@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool):
        click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
        query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
        with db.engine.begin() as conn:
-            rs = conn.execute(db.text(query))
+            rs = conn.execute(sa.text(query))
            for i in rs:
                all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
    click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
@@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool):
                f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
            )
            with db.engine.begin() as conn:
-                rs = conn.execute(db.text(query))
+                rs = conn.execute(sa.text(query))
                for i in rs:
                    all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
        elif ids_table["type"] == "text":
@@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool):
                f"FROM {ids_table['table']}"
            )
            with db.engine.begin() as conn:
-                rs = conn.execute(db.text(query))
+                rs = conn.execute(sa.text(query))
                for i in rs:
                    for j in i[0]:
                        all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool):
                f"FROM {ids_table['table']}"
            )
            with db.engine.begin() as conn:
-                rs = conn.execute(db.text(query))
+                rs = conn.execute(sa.text(query))
                for i in rs:
                    for j in i[0]:
                        all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool):
            click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
            query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
            with db.engine.begin() as conn:
-                conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
+                conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
    except Exception as e:
        click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
        return
@@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool):
        click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
        query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
        with db.engine.begin() as conn:
-            rs = conn.execute(db.text(query))
+            rs = conn.execute(sa.text(query))
            for i in rs:
                all_files_in_tables.append(str(i[0]))
    click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
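Every raw-SQL call site in these commands moves from `db.text(...)` to `sa.text(...)`. Flask-SQLAlchemy's `db.text` is the same `sqlalchemy.text` function re-exposed on the extension object, so behavior is unchanged; the rewrite simply uses the canonical import added at the top of the file. Equivalent standalone usage, with a throwaway SQLite engine for illustration:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")  # illustrative in-memory engine
with engine.begin() as conn:
    conn.execute(sa.text("CREATE TABLE apps (id INTEGER, created_at TEXT)"))
    # sa.text() builds a TextClause with named bind parameters; this is
    # exactly the object db.text() would have produced.
    rows = conn.execute(
        sa.text("SELECT id FROM apps WHERE created_at >= :start"),
        {"start": "2024-01-01"},
    )
    print(rows.fetchall())
```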
@@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
+from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
@@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
+from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
@@ -52,6 +54,7 @@ class StorageConfig(BaseSettings):
        "aliyun-oss",
        "azure-blob",
        "baidu-obs",
+        "clickzetta-volume",
        "google-storage",
        "huawei-obs",
        "oci-storage",
@@ -61,8 +64,9 @@ class StorageConfig(BaseSettings):
        "local",
    ] = Field(
        description="Type of storage to use."
-        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
-        "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
+        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
+        "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
+        "'volcengine-tos', 'supabase'. Default is 'opendal'.",
        default="opendal",
    )
@@ -215,7 +219,7 @@ class DatabaseConfig(BaseSettings):

class CeleryConfig(DatabaseConfig):
    CELERY_BACKEND: str = Field(
-        description="Backend for Celery task results. Options: 'database', 'redis'.",
+        description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
        default="redis",
    )
@@ -245,7 +249,12 @@ class CeleryConfig(DatabaseConfig):

    @computed_field
    def CELERY_RESULT_BACKEND(self) -> str | None:
-        return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
+        if self.CELERY_BACKEND in ("database", "rabbitmq"):
+            return f"db+{self.SQLALCHEMY_DATABASE_URI}"
+        elif self.CELERY_BACKEND == "redis":
+            return self.CELERY_BROKER_URL
+        else:
+            return None

    @property
    def BROKER_USE_SSL(self) -> bool:
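The reworked `CELERY_RESULT_BACKEND` routes the new 'rabbitmq' option to the SQLAlchemy result store, since an AMQP broker is not a durable place to keep task results; unknown values now yield None instead of falling through to the broker URL. The same branching, extracted into a plain function for illustration:

```python
def resolve_result_backend(backend: str, database_uri: str, broker_url: str) -> str | None:
    """Mirror of the new CELERY_RESULT_BACKEND branching (illustrative).

    RabbitMQ can broker messages but is a poor result store, so results
    fall back to Celery's 'db+' SQLAlchemy backend for that option too.
    """
    if backend in ("database", "rabbitmq"):
        return f"db+{database_uri}"
    elif backend == "redis":
        return broker_url
    return None


assert resolve_result_backend("redis", "postgresql://dify", "redis://localhost:6379/1") == "redis://localhost:6379/1"
assert resolve_result_backend("rabbitmq", "postgresql://dify", "amqp://localhost") == "db+postgresql://dify"
assert resolve_result_backend("unknown", "postgresql://dify", "amqp://localhost") is None
```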
@@ -298,6 +307,7 @@ class MiddlewareConfig(
    AliyunOSSStorageConfig,
    AzureBlobStorageConfig,
    BaiduOBSStorageConfig,
+    ClickZettaVolumeStorageConfig,
    GoogleCloudStorageConfig,
    HuaweiCloudOBSStorageConfig,
    OCIStorageConfig,
@@ -310,6 +320,7 @@ class MiddlewareConfig(
    VectorStoreConfig,
    AnalyticdbConfig,
    ChromaConfig,
+    ClickzettaConfig,
    HuaweiCloudConfig,
    MilvusConfig,
    MyScaleConfig,
@@ -0,0 +1,65 @@
+"""ClickZetta Volume Storage Configuration"""
+
+from typing import Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class ClickZettaVolumeStorageConfig(BaseSettings):
+    """Configuration for ClickZetta Volume storage."""
+
+    CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
+        description="Username for ClickZetta Volume authentication",
+        default=None,
+    )
+
+    CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
+        description="Password for ClickZetta Volume authentication",
+        default=None,
+    )
+
+    CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
+        description="ClickZetta instance identifier",
+        default=None,
+    )
+
+    CLICKZETTA_VOLUME_SERVICE: str = Field(
+        description="ClickZetta service endpoint",
+        default="api.clickzetta.com",
+    )
+
+    CLICKZETTA_VOLUME_WORKSPACE: str = Field(
+        description="ClickZetta workspace name",
+        default="quick_start",
+    )
+
+    CLICKZETTA_VOLUME_VCLUSTER: str = Field(
+        description="ClickZetta virtual cluster name",
+        default="default_ap",
+    )
+
+    CLICKZETTA_VOLUME_SCHEMA: str = Field(
+        description="ClickZetta schema name",
+        default="dify",
+    )
+
+    CLICKZETTA_VOLUME_TYPE: str = Field(
+        description="ClickZetta volume type (table|user|external)",
+        default="user",
+    )
+
+    CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
+        description="ClickZetta volume name for external volumes",
+        default=None,
+    )
+
+    CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
+        description="Prefix for ClickZetta volume table names",
+        default="dataset_",
+    )
+
+    CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
+        description="Directory prefix for User Volume to organize Dify files",
+        default="dify_km",
+    )
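As a pydantic-settings model, every field above is populated from an identically named environment variable. A trimmed, hypothetical stand-in showing the mechanism (the two-field class below is illustrative, not the full config):

```python
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class VolumeCfgDemo(BaseSettings):
    # Trimmed, hypothetical stand-in for ClickZettaVolumeStorageConfig.
    CLICKZETTA_VOLUME_TYPE: str = Field(default="user")
    CLICKZETTA_VOLUME_NAME: Optional[str] = Field(default=None)


os.environ["CLICKZETTA_VOLUME_TYPE"] = "external"
os.environ["CLICKZETTA_VOLUME_NAME"] = "my_volume"  # illustrative value

cfg = VolumeCfgDemo()
assert cfg.CLICKZETTA_VOLUME_TYPE == "external"
assert cfg.CLICKZETTA_VOLUME_NAME == "my_volume"
```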
@@ -0,0 +1,69 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class ClickzettaConfig(BaseModel):
+    """
+    Clickzetta Lakehouse vector database configuration
+    """
+
+    CLICKZETTA_USERNAME: Optional[str] = Field(
+        description="Username for authenticating with Clickzetta Lakehouse",
+        default=None,
+    )
+
+    CLICKZETTA_PASSWORD: Optional[str] = Field(
+        description="Password for authenticating with Clickzetta Lakehouse",
+        default=None,
+    )
+
+    CLICKZETTA_INSTANCE: Optional[str] = Field(
+        description="Clickzetta Lakehouse instance ID",
+        default=None,
+    )
+
+    CLICKZETTA_SERVICE: Optional[str] = Field(
+        description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
+        default="api.clickzetta.com",
+    )
+
+    CLICKZETTA_WORKSPACE: Optional[str] = Field(
+        description="Clickzetta workspace name",
+        default="default",
+    )
+
+    CLICKZETTA_VCLUSTER: Optional[str] = Field(
+        description="Clickzetta virtual cluster name",
+        default="default_ap",
+    )
+
+    CLICKZETTA_SCHEMA: Optional[str] = Field(
+        description="Database schema name in Clickzetta",
+        default="public",
+    )
+
+    CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
+        description="Batch size for bulk insert operations",
+        default=100,
+    )
+
+    CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
+        description="Enable inverted index for full-text search capabilities",
+        default=True,
+    )
+
+    CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
+        description="Analyzer type for full-text search: keyword, english, chinese, unicode",
+        default="chinese",
+    )
+
+    CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
+        description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
+        default="smart",
+    )
+
+    CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
+        description="Distance function for vector similarity: l2_distance or cosine_distance",
+        default="cosine_distance",
+    )
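`CLICKZETTA_VECTOR_DISTANCE_FUNCTION` selects one of two standard metrics; their textbook definitions are below for reference, though ClickZetta's server-side implementation may differ in details such as normalization. The default, cosine distance, compares direction only, which suits most embedding models; L2 distance also weighs vector magnitude.

```python
import math


def l2_distance(x: list[float], y: list[float]) -> float:
    # Euclidean distance: sqrt(sum over i of (x_i - y_i)^2).
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def cosine_distance(x: list[float], y: list[float]) -> float:
    # 1 - cosine similarity; assumes non-zero vectors.
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / norm


assert cosine_distance([1.0, 0.0], [2.0, 0.0]) == 0.0  # same direction, magnitude ignored
assert l2_distance([1.0, 0.0], [2.0, 0.0]) == 1.0      # magnitude matters here
```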
@@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
        description="AccessKey secret for the instance name",
        default=None,
    )
+
+    TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
+        description="Whether to normalize full-text search scores to [0, 1]",
+        default=False,
+    )
@@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])

-VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
+VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])

-AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
+AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
@@ -84,6 +84,7 @@ from .datasets import (
    external,
    hit_testing,
    metadata,
+    upload_file,
    website,
)
@@ -100,7 +100,7 @@ class AnnotationReplyActionStatusApi(Resource):
        return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200


-class AnnotationListApi(Resource):
+class AnnotationApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
@@ -123,6 +123,23 @@ class AnnotationListApi(Resource):
        }
        return response, 200

+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("annotation")
+    @marshal_with(annotation_fields)
+    def post(self, app_id):
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        app_id = str(app_id)
+        parser = reqparse.RequestParser()
+        parser.add_argument("question", required=True, type=str, location="json")
+        parser.add_argument("answer", required=True, type=str, location="json")
+        args = parser.parse_args()
+        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
+        return annotation
+
    @setup_required
    @login_required
    @account_initialization_required
@@ -131,8 +148,25 @@ class AnnotationListApi(Resource):
            raise Forbidden()

        app_id = str(app_id)
-        AppAnnotationService.clear_all_annotations(app_id)
-        return {"result": "success"}, 204
+
+        # Use request.args.getlist to get annotation_ids array directly
+        annotation_ids = request.args.getlist("annotation_id")
+
+        # If annotation_ids are provided, handle batch deletion
+        if annotation_ids:
+            # Check if any annotation_ids contain empty strings or invalid values
+            if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id):
+                return {
+                    "code": "bad_request",
+                    "message": "annotation_ids are required if the parameter is provided.",
+                }, 400
+
+            result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
+            return result, 204
+        # If no annotation_ids are provided, handle clearing all annotations
+        else:
+            AppAnnotationService.clear_all_annotations(app_id)
+            return {"result": "success"}, 204


class AnnotationExportApi(Resource):
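The consolidated DELETE handler overloads a single route: repeated `annotation_id` query parameters trigger a batch delete, and their absence clears all annotations. A hypothetical client call; the host, path prefix, and auth header are illustrative, while the route shape comes from the registration further down:

```python
import requests

APP_ID = "11111111-1111-1111-1111-111111111111"  # illustrative

# Batch-delete two annotations; omit `params` entirely to clear them all.
resp = requests.delete(
    f"https://example.com/console/api/apps/{APP_ID}/annotations",  # host/prefix illustrative
    params=[
        ("annotation_id", "22222222-2222-2222-2222-222222222222"),
        ("annotation_id", "33333333-3333-3333-3333-333333333333"),
    ],
    headers={"Authorization": "Bearer <console-session-token>"},  # auth scheme illustrative
)
# 204 on success; empty or whitespace-only annotation_id values return 400.
```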
@@ -149,25 +183,6 @@ class AnnotationExportApi(Resource):
        return response, 200


-class AnnotationCreateApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @cloud_edition_billing_resource_check("annotation")
-    @marshal_with(annotation_fields)
-    def post(self, app_id):
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        app_id = str(app_id)
-        parser = reqparse.RequestParser()
-        parser.add_argument("question", required=True, type=str, location="json")
-        parser.add_argument("answer", required=True, type=str, location="json")
-        args = parser.parse_args()
-        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
-        return annotation
-
-
class AnnotationUpdateDeleteApi(Resource):
    @setup_required
    @login_required
@@ -210,14 +225,15 @@ class AnnotationBatchImportApi(Resource):
            raise Forbidden()

        app_id = str(app_id)
-        # get file from request
-        file = request.files["file"]
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

+        # get file from request
+        file = request.files["file"]
        # check file type
        if not file.filename or not file.filename.lower().endswith(".csv"):
            raise ValueError("Invalid file type. Only CSV files are allowed")
@@ -276,7 +292,7 @@ api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply
api.add_resource(
    AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>"
)
-api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations")
+api.add_resource(AnnotationApi, "/apps/<uuid:app_id>/annotations")
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import")
@@ -28,6 +28,12 @@ from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]


+def _validate_description_length(description):
+    if description and len(description) > 400:
+        raise ValueError("Description cannot exceed 400 characters.")
+    return description
+
+
class AppListApi(Resource):
    @setup_required
    @login_required
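`_validate_description_length` is wired in below as the reqparse `type` callable: reqparse calls it with the raw JSON value and converts a raised ValueError into a 400 response. The `description and` guard makes it None-safe, because a missing field arrives as None and `len(None)` would raise TypeError instead of the intended validation error. A minimal check of that contract:

```python
def _validate_description_length(description):
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


assert _validate_description_length(None) is None   # absent field passes through
assert _validate_description_length("ok") == "ok"   # short values pass through
try:
    _validate_description_length("x" * 401)
except ValueError:
    pass  # reqparse turns this into a 400 for the client
```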
|
@ -94,7 +100,7 @@ class AppListApi(Resource):
|
|||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
|
|
@@ -146,7 +152,7 @@ class AppApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("description", type=_validate_description_length, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
@@ -189,7 +195,7 @@ class AppCopyApi(Resource):

        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("description", type=_validate_description_length, location="json")
        parser.add_argument("icon_type", type=str, location="json")
        parser.add_argument("icon", type=str, location="json")
        parser.add_argument("icon_background", type=str, location="json")
@@ -67,7 +67,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "message_count": i.message_count})
@@ -176,7 +176,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@@ -234,7 +234,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append(
                {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
@@ -310,7 +310,7 @@ ORDER BY
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append(
                {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
@@ -373,7 +373,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append(
                {
@@ -435,7 +435,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
@@ -495,7 +495,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
@@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal

import pytz
+import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@@ -71,7 +72,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "runs": i.runs})
@@ -133,7 +134,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@@ -195,7 +196,7 @@ WHERE
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append(
                {
@@ -277,7 +278,7 @@ GROUP BY
    response_data = []

    with db.engine.begin() as conn:
-        rs = conn.execute(db.text(sql_query), arg_dict)
+        rs = conn.execute(sa.text(sql_query), arg_dict)
        for i in rs:
            response_data.append(
                {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
@@ -41,7 +41,7 @@ def _validate_name(name):


def _validate_description_length(description):
-    if len(description) > 400:
+    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
@@ -113,7 +113,7 @@ class DatasetListApi(Resource):
        )
        parser.add_argument(
            "description",
-            type=str,
+            type=_validate_description_length,
            nullable=True,
            required=False,
            default="",
@@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource):
                | VectorType.HUAWEI_CLOUD
                | VectorType.TENCENT
                | VectorType.MATRIXONE
+                | VectorType.CLICKZETTA
            ):
                return {
                    "retrieval_method": [
@@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                | VectorType.TENCENT
                | VectorType.HUAWEI_CLOUD
                | VectorType.MATRIXONE
+                | VectorType.CLICKZETTA
            ):
                return {
                    "retrieval_method": [
@@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource):
        return marshal(document_dict, document_status_fields)


-class DocumentDetailApi(DocumentResource):
+class DocumentApi(DocumentResource):
    METADATA_CHOICES = {"all", "only", "without"}

    @setup_required
@@ -730,6 +730,28 @@ class DocumentDetailApi(DocumentResource):

        return response, 200

+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
+    def delete(self, dataset_id, document_id):
+        dataset_id = str(dataset_id)
+        document_id = str(document_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+
+        document = self.get_document(dataset_id, document_id)
+
+        try:
+            DocumentService.delete_document(document)
+        except services.errors.document.DocumentIndexingError:
+            raise DocumentIndexingError("Cannot delete document during indexing.")
+
+        return {"result": "success"}, 204
+

class DocumentProcessingApi(DocumentResource):
    @setup_required
@@ -768,30 +790,6 @@ class DocumentProcessingApi(DocumentResource):
        return {"result": "success"}, 200


-class DocumentDeleteApi(DocumentResource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
-    def delete(self, dataset_id, document_id):
-        dataset_id = str(dataset_id)
-        document_id = str(document_id)
-        dataset = DatasetService.get_dataset(dataset_id)
-        if dataset is None:
-            raise NotFound("Dataset not found.")
-        # check user's model setting
-        DatasetService.check_dataset_model_setting(dataset)
-
-        document = self.get_document(dataset_id, document_id)
-
-        try:
-            DocumentService.delete_document(document)
-        except services.errors.document.DocumentIndexingError:
-            raise DocumentIndexingError("Cannot delete document during indexing.")
-
-        return {"result": "success"}, 204
-
-
class DocumentMetadataApi(DocumentResource):
    @setup_required
    @login_required
@@ -1037,11 +1035,10 @@ api.add_resource(
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
-api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
+api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
    DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
-api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
@@ -0,0 +1,62 @@
+from flask_login import current_user
+from flask_restful import Resource
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.wraps import (
+    account_initialization_required,
+    setup_required,
+)
+from core.file import helpers as file_helpers
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import UploadFile
+from services.dataset_service import DocumentService
+
+
+class UploadFileApi(Resource):
+    @setup_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id):
+        """Get upload file."""
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = (
+            db.session.query(Dataset)
+            .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
+            .first()
+        )
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check upload file
+        if document.data_source_type != "upload_file":
+            raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
+        data_source_info = document.data_source_info_dict
+        if data_source_info and "upload_file_id" in data_source_info:
+            file_id = data_source_info["upload_file_id"]
+            upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+            if not upload_file:
+                raise NotFound("UploadFile not found.")
+        else:
+            raise ValueError("Upload file id not found in document data source info.")
+
+        url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
+        return {
+            "id": upload_file.id,
+            "name": upload_file.name,
+            "size": upload_file.size,
+            "extension": upload_file.extension,
+            "url": url,
+            "download_url": f"{url}&as_attachment=true",
+            "mime_type": upload_file.mime_type,
+            "created_by": upload_file.created_by,
+            "created_at": upload_file.created_at.timestamp(),
+        }, 200
+
+
+api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
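A hypothetical call against the new upload-file route; the host, path prefix, and auth header are illustrative, while the response keys come from the handler above:

```python
import requests

DATASET_ID = "11111111-1111-1111-1111-111111111111"   # illustrative
DOCUMENT_ID = "22222222-2222-2222-2222-222222222222"  # illustrative

resp = requests.get(
    f"https://example.com/console/api/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/upload-file",
    headers={"Authorization": "Bearer <console-session-token>"},  # auth scheme illustrative
)
info = resp.json()
# Keys per the handler: id, name, size, extension, url, download_url,
# mime_type, created_by, created_at. `download_url` appends
# &as_attachment=true to the signed URL to force a download.
```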
@@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
    code = 429


-class CompilanceRateLimitError(BaseHTTPException):
-    error_code = "compilance_rate_limit"
+class ComplianceRateLimitError(BaseHTTPException):
+    error_code = "compliance_rate_limit"
    description = "Rate limit exceeded for downloading compliance report."
    code = 429
@@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource):
        # filter out apps that user doesn't have access to
        if FeatureService.get_system_features().webapp_auth.enabled:
            user_id = current_user.id
-            res = []
            app_ids = [installed_app["app"].id for installed_app in installed_app_list]
            webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)

+            # Pre-filter out apps without setting or with sso_verified
+            filtered_installed_apps = []
+            app_id_to_app_code = {}
+
            for installed_app in installed_app_list:
-                webapp_setting = webapp_settings.get(installed_app["app"].id)
-                if not webapp_setting:
-                    continue
-                if webapp_setting.access_mode == "sso_verified":
+                app_id = installed_app["app"].id
+                webapp_setting = webapp_settings.get(app_id)
+                if not webapp_setting or webapp_setting.access_mode == "sso_verified":
                    continue
-                app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
-                if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
-                    user_id=user_id,
-                    app_code=app_code,
-                ):
-                    res.append(installed_app)
+                app_code = AppService.get_app_code_by_id(str(app_id))
+                app_id_to_app_code[app_id] = app_code
+                filtered_installed_apps.append(installed_app)
+
+            app_codes = list(app_id_to_app_code.values())
+
+            # Batch permission check
+            permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
+                user_id=user_id,
+                app_codes=app_codes,
+            )
+
+            # Keep only allowed apps
+            res = []
+            for installed_app in filtered_installed_apps:
+                app_id = installed_app["app"].id
+                app_code = app_id_to_app_code[app_id]
+                if permissions.get(app_code):
+                    res.append(installed_app)

            installed_app_list = res
+            logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)
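This rewrite turns an N-plus-one pattern, one `is_user_allowed_to_access_webapp` round trip per installed app, into a single batched call followed by a local filter. The generic shape of that refactor, with illustrative names:

```python
from collections.abc import Callable, Iterable
from typing import Any


def filter_allowed(
    items: Iterable[Any],
    get_code: Callable[[Any], str],
    batch_check: Callable[[list[str]], dict[str, bool]],
) -> list[Any]:
    """One permission round trip for the whole list instead of one per item."""
    items = list(items)
    codes = [get_code(item) for item in items]
    allowed = batch_check(codes)  # single batched call
    return [item for item, code in zip(items, codes) if allowed.get(code)]


# Toy usage: "B" is denied, everything else allowed.
demo = ["a", "b", "c"]
assert filter_allowed(demo, str.upper, lambda codes: {c: c != "B" for c in codes}) == ["a", "c"]
```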
@@ -49,7 +49,6 @@ class FileApi(Resource):
    @marshal_with(file_fields)
    @cloud_edition_billing_resource_check("documents")
    def post(self):
-        file = request.files["file"]
        source_str = request.form.get("source")
        source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@@ -58,6 +57,7 @@ class FileApi(Resource):

        if len(request.files) > 1:
            raise TooManyFilesError()
+        file = request.files["file"]

        if not file.filename:
            raise FilenameNotExistsError
@@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource):
    @account_initialization_required
    @cloud_edition_billing_resource_check("workspace_custom")
    def post(self):
-        # get file from request
-        file = request.files["file"]
-
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()
@@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource):
        if len(request.files) > 1:
            raise TooManyFilesError()

+        # get file from request
+        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError
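Every upload endpoint touched by this commit gets the same reorder: the presence and count checks now run before `request.files["file"]` is evaluated. The order matters because indexing werkzeug's MultiDict with a missing key raises BadRequestKeyError, a generic 400, instead of the intended NoFileUploadedError. A self-contained sketch of the corrected shape (the error classes are stand-ins for the controllers' own):

```python
class NoFileUploadedError(Exception):
    """Stand-in for the controllers' NoFileUploadedError."""


class TooManyFilesError(Exception):
    """Stand-in for the controllers' TooManyFilesError."""


def read_single_upload(files):
    """`files` plays the role of request.files (a werkzeug MultiDict)."""
    if "file" not in files:        # 1) presence check first
        raise NoFileUploadedError()
    if len(files) > 1:             # 2) then the count check
        raise TooManyFilesError()
    return files["file"]           # 3) only now is indexing safe


assert read_single_upload({"file": b"payload"}) == b"payload"
```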
@@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)

from . import index
-from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
+from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
@@ -2,7 +2,7 @@ import logging

from flask import request
from flask_restful import Resource, reqparse
-from werkzeug.exceptions import InternalServerError, NotFound
+from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

import services
from controllers.service_api import api
@@ -30,6 +30,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
+from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@@ -113,7 +114,7 @@ class ChatApi(Resource):
        parser.add_argument("conversation_id", type=uuid_value, location="json")
        parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
        parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
-
+        parser.add_argument("workflow_id", type=str, required=False, location="json")
        args = parser.parse_args()

        external_trace_id = get_external_trace_id(request)
@@ -128,6 +129,12 @@ class ChatApi(Resource):
            )

            return helper.compact_generate_response(response)
+        except WorkflowNotFoundError as ex:
+            raise NotFound(str(ex))
+        except IsDraftWorkflowError as ex:
+            raise BadRequest(str(ex))
+        except WorkflowIdFormatError as ex:
+            raise BadRequest(str(ex))
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        except services.errors.conversation.ConversationCompletedError:
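With the new `workflow_id` parameter a Service API chat request can pin a specific published workflow, and the added handlers map the service-layer errors to 404 for an unknown workflow and 400 for a draft or malformed id. A hypothetical request; the host, API key, and the chat-messages path are illustrative assumptions, with only the `/v1` prefix taken from the blueprint registration shown earlier:

```python
import requests

resp = requests.post(
    "https://example.com/v1/chat-messages",  # host and path are assumptions
    json={
        "query": "Hello",
        "inputs": {},
        "user": "end-user-123",
        "response_mode": "blocking",
        "workflow_id": "44444444-4444-4444-4444-444444444444",  # illustrative
    },
    headers={"Authorization": "Bearer <service-api-key>"},
)
# Unknown workflow_id -> 404; draft workflow or malformed id -> 400.
```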
@@ -1,7 +1,9 @@
+import json
+
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import BadRequest, NotFound

import services
from controllers.service_api import api
@@ -15,6 +17,7 @@ from fields.conversation_fields import (
    simple_conversation_fields,
)
from fields.conversation_variable_fields import (
+    conversation_variable_fields,
    conversation_variable_infinite_scroll_pagination_fields,
)
from libs.helper import uuid_value
@@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
            raise NotFound("Conversation Not Exists.")


+class ConversationVariableDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    @marshal_with(conversation_variable_fields)
+    def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
+        """Update a conversation variable's value"""
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+            raise NotChatAppError()
+
+        conversation_id = str(c_id)
+        variable_id = str(variable_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("value", required=True, location="json")
+        args = parser.parse_args()
+
+        try:
+            return ConversationService.update_conversation_variable(
+                app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
+            )
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationVariableNotExistsError:
+            raise NotFound("Conversation Variable Not Exists.")
+        except services.errors.conversation.ConversationVariableTypeMismatchError as e:
+            raise BadRequest(str(e))
+
+
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
+api.add_resource(
+    ConversationVariableDetailApi,
+    "/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
+    endpoint="conversation_variable_detail",
+    methods=["PUT"],
+)
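The PUT handler parses the submitted value with `json.loads`, so clients send the new value as a JSON-encoded string inside the JSON body; a mismatch with the variable's declared type returns 400. A hypothetical call, with the host and key illustrative and the route taken from the registration above:

```python
import json

import requests

CONV_ID = "55555555-5555-5555-5555-555555555555"  # illustrative
VAR_ID = "66666666-6666-6666-6666-666666666666"   # illustrative

resp = requests.put(
    f"https://example.com/v1/conversations/{CONV_ID}/variables/{VAR_ID}",  # host illustrative
    json={
        "user": "end-user-123",
        "value": json.dumps(["alpha", "beta"]),  # note: value is itself JSON-encoded
    },
    headers={"Authorization": "Bearer <service-api-key>"},
)
# 404 if the conversation or variable is missing; 400 on a type mismatch.
```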
@@ -107,3 +107,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
    error_code = "unsupported_file_type"
    description = "File type not allowed."
    code = 415
+
+
+class FileNotFoundError(BaseHTTPException):
+    error_code = "file_not_found"
+    description = "The requested file was not found."
+    code = 404
+
+
+class FileAccessDeniedError(BaseHTTPException):
+    error_code = "file_access_denied"
+    description = "Access to the requested file is denied."
+    code = 403
@@ -20,18 +20,17 @@ class FileApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
    @marshal_with(file_fields)
    def post(self, app_model: App, end_user: EndUser):
-        file = request.files["file"]
-
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()

-        if not file.mimetype:
-            raise UnsupportedFileTypeError()
-
        if len(request.files) > 1:
            raise TooManyFilesError()

+        file = request.files["file"]
+        if not file.mimetype:
+            raise UnsupportedFileTypeError()
+
        if not file.filename:
            raise FilenameNotExistsError
@@ -0,0 +1,186 @@
import logging
from urllib.parse import quote

from flask import Response
from flask_restful import Resource, reqparse

from controllers.service_api import api
from controllers.service_api.app.error import (
    FileAccessDeniedError,
    FileNotFoundError,
)
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, EndUser, Message, MessageFile, UploadFile

logger = logging.getLogger(__name__)


class FilePreviewApi(Resource):
    """
    Service API File Preview endpoint

    Provides secure file preview/download functionality for external API users.
    Files can only be accessed if they belong to messages within the requesting app's context.
    """

    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
    def get(self, app_model: App, end_user: EndUser, file_id: str):
        """
        Preview/Download a file that was uploaded via the Service API

        Args:
            app_model: The authenticated app model
            end_user: The authenticated end user (optional)
            file_id: UUID of the file to preview

        Query Parameters:
            user: Optional user identifier
            as_attachment: Boolean, whether to download as an attachment (default: false)

        Returns:
            Stream response with the file content

        Raises:
            FileNotFoundError: File does not exist
            FileAccessDeniedError: File access denied (not owned by the app)
        """
        file_id = str(file_id)

        # Parse query parameters
        parser = reqparse.RequestParser()
        parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
        args = parser.parse_args()

        # Validate file ownership and get file objects
        message_file, upload_file = self._validate_file_ownership(file_id, app_model.id)

        # Get file content generator
        try:
            generator = storage.load(upload_file.key, stream=True)
        except Exception as e:
            raise FileNotFoundError(f"Failed to load file content: {str(e)}")

        # Build response with appropriate headers
        response = self._build_file_response(generator, upload_file, args["as_attachment"])

        return response

    def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]:
        """
        Validate that the file belongs to a message within the requesting app's context

        Security validations performed:
        1. File exists in the MessageFile table (was used in a conversation)
        2. Message belongs to the requesting app
        3. UploadFile record exists and is accessible
        4. File tenant matches the app tenant (additional security layer)

        Args:
            file_id: UUID of the file to validate
            app_id: UUID of the requesting app

        Returns:
            Tuple of (MessageFile, UploadFile) if validation passes

        Raises:
            FileNotFoundError: File or related records not found
            FileAccessDeniedError: File does not belong to the app's context
        """
        try:
            # Input validation
            if not file_id or not app_id:
                raise FileAccessDeniedError("Invalid file or app identifier")

            # First, find the MessageFile that references this upload file
            message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first()

            if not message_file:
                raise FileNotFoundError("File not found in message context")

            # Get the message and verify it belongs to the requesting app
            message = (
                db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first()
            )

            if not message:
                raise FileAccessDeniedError("File access denied: not owned by requesting app")

            # Get the actual upload file record
            upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()

            if not upload_file:
                raise FileNotFoundError("Upload file record not found")

            # Additional security: verify tenant isolation
            app = db.session.query(App).where(App.id == app_id).first()
            if app and upload_file.tenant_id != app.tenant_id:
                raise FileAccessDeniedError("File access denied: tenant mismatch")

            return message_file, upload_file

        except (FileNotFoundError, FileAccessDeniedError):
            # Re-raise our custom exceptions
            raise
        except Exception as e:
            # Log unexpected errors for debugging
            logger.exception(
                "Unexpected error during file ownership validation",
                extra={"file_id": file_id, "app_id": app_id, "error": str(e)},
            )
            raise FileAccessDeniedError("File access validation failed")

    def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response:
        """
        Build a Flask Response object with appropriate headers for file streaming

        Args:
            generator: File content generator from storage
            upload_file: UploadFile database record
            as_attachment: Whether to set Content-Disposition as attachment

        Returns:
            Flask Response object with streaming file content
        """
        response = Response(
            generator,
            mimetype=upload_file.mime_type,
            direct_passthrough=True,
            headers={},
        )

        # Add Content-Length if known
        if upload_file.size and upload_file.size > 0:
            response.headers["Content-Length"] = str(upload_file.size)

        # Add Accept-Ranges header for audio/video files to support seeking
        if upload_file.mime_type in [
            "audio/mpeg",
            "audio/wav",
            "audio/mp4",
            "audio/ogg",
            "audio/flac",
            "audio/aac",
            "video/mp4",
            "video/webm",
            "video/quicktime",
            "audio/x-m4a",
        ]:
            response.headers["Accept-Ranges"] = "bytes"

        # Set Content-Disposition for downloads
        if as_attachment and upload_file.name:
            encoded_filename = quote(upload_file.name)
            response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
            # Override content-type for downloads to force download
            response.headers["Content-Type"] = "application/octet-stream"

        # Add caching headers for performance
        response.headers["Cache-Control"] = "public, max-age=3600"  # Cache for 1 hour

        return response


# Register the API endpoint
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
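
A usage sketch of the preview route registered above. The `user` and `as_attachment` query parameters come from the handler's docstring; the base URL, API key, and file ID are placeholders.

```python
import requests

# Hypothetical values; the route itself comes from the registration above.
BASE_URL = "http://localhost:5001/v1"
API_KEY = "app-xxxxxxxx"
file_id = "33333333-3333-3333-3333-333333333333"

# Stream the file; as_attachment=true forces a download disposition.
with requests.get(
    f"{BASE_URL}/files/{file_id}/preview",
    headers={"Authorization": f"Bearer {API_KEY}"},
    params={"as_attachment": "true", "user": "end-user-123"},
    stream=True,
) as resp:
    resp.raise_for_status()
    with open("downloaded_file.bin", "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)
```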
@@ -5,7 +5,7 @@ from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound

from controllers.service_api import api
from controllers.service_api.app.error import (
@@ -34,6 +34,7 @@ from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
        raise InternalServerError()


class WorkflowRunByIdApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, workflow_id: str):
        """
        Run a specific workflow by ID
        """
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode != AppMode.WORKFLOW:
            raise NotWorkflowAppError()

        parser = reqparse.RequestParser()
        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
        parser.add_argument("files", type=list, required=False, location="json")
        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
        args = parser.parse_args()

        # Add workflow_id to args for AppGenerateService
        args["workflow_id"] = workflow_id

        external_trace_id = get_external_trace_id(request)
        if external_trace_id:
            args["external_trace_id"] = external_trace_id
        streaming = args.get("response_mode") == "streaming"

        try:
            response = AppGenerateService.generate(
                app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
            )

            return helper.compact_generate_response(response)
        except WorkflowNotFoundError as ex:
            raise NotFound(str(ex))
        except IsDraftWorkflowError as ex:
            raise BadRequest(str(ex))
        except WorkflowIdFormatError as ex:
            raise BadRequest(str(ex))
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except InvokeError as e:
            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise e
        except Exception:
            logging.exception("internal server error.")
            raise InternalServerError()


class WorkflowTaskStopApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, task_id: str):
@@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):

api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
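
A minimal sketch of calling the new run-by-ID route. The `inputs`, `response_mode`, and `user` fields come from the request parser and token wrapper in the handler above; the base URL, API key, and IDs are placeholders.

```python
import requests

# Hypothetical values; the /workflows/<workflow_id>/run path comes from the routes above.
BASE_URL = "http://localhost:5001/v1"
API_KEY = "app-xxxxxxxx"
workflow_id = "44444444-4444-4444-4444-444444444444"

resp = requests.post(
    f"{BASE_URL}/workflows/{workflow_id}/run",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "inputs": {"query": "hello"},  # must match the workflow's input schema
        "response_mode": "blocking",   # or "streaming"
        "user": "end-user-123",
    },
)
print(resp.status_code, resp.json())
```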
@@ -29,7 +29,7 @@ def _validate_name(name):


def _validate_description_length(description):
    if len(description) > 400:
    if description and len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description
@@ -87,7 +87,7 @@ class DatasetListApi(DatasetApiResource):
        )
        parser.add_argument(
            "description",
            type=str,
            type=_validate_description_length,
            nullable=True,
            required=False,
            default="",
@@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource):
                args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
            )

        # save file info
        file = request.files["file"]
        # check file
        if "file" not in request.files:
            raise NoFileUploadedError()
@@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource):
        if len(request.files) > 1:
            raise TooManyFilesError()

        # save file info
        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError
@@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
        return documents_and_batch_fields, 200


class DocumentDeleteApi(DatasetApiResource):
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id, dataset_id, document_id):
        """Delete document."""
        document_id = str(document_id)
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        # get dataset info
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

        if not dataset:
            raise ValueError("Dataset does not exist.")

        document = DocumentService.get_document(dataset.id, document_id)

        # 404 if document not found
        if document is None:
            raise NotFound("Document Not Exists.")

        # 403 if document is archived
        if DocumentService.check_archived(document):
            raise ArchivedDocumentImmutableError()

        try:
            # delete document
            DocumentService.delete_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError("Cannot delete document during indexing.")

        return 204


class DocumentListApi(DatasetApiResource):
    def get(self, tenant_id, dataset_id):
        dataset_id = str(dataset_id)
@@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
        return data


class DocumentDetailApi(DatasetApiResource):
class DocumentApi(DatasetApiResource):
    METADATA_CHOICES = {"all", "only", "without"}

    def get(self, tenant_id, dataset_id, document_id):
@@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):

        return response

    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, tenant_id, dataset_id, document_id):
        """Delete document."""
        document_id = str(document_id)
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        # get dataset info
        dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

        if not dataset:
            raise ValueError("Dataset does not exist.")

        document = DocumentService.get_document(dataset.id, document_id)

        # 404 if document not found
        if document is None:
            raise NotFound("Document Not Exists.")

        # 403 if document is archived
        if DocumentService.check_archived(document):
            raise ArchivedDocumentImmutableError()

        try:
            # delete document
            DocumentService.delete_document(document)
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError("Cannot delete document during indexing.")

        return 204


api.add_resource(
    DocumentAddByTextApi,
@@ -588,7 +586,6 @@ api.add_resource(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
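
With `DocumentDeleteApi` folded into the consolidated `DocumentApi`, GET and DELETE now share one route. A hedged sketch of the delete call follows; the base URL, key format, and IDs are assumptions, while the 204 result comes from the handler above.

```python
import requests

# Hypothetical values; the consolidated route comes from the registrations above.
BASE_URL = "http://localhost:5001/v1"
API_KEY = "dataset-xxxxxxxx"  # Knowledge API key; format assumed
dataset_id = "55555555-5555-5555-5555-555555555555"
document_id = "66666666-6666-6666-6666-666666666666"

# DELETE and GET now share one resource class (DocumentApi) on this route.
resp = requests.delete(
    f"{BASE_URL}/datasets/{dataset_id}/documents/{document_id}",
    headers={"Authorization": f"Bearer {API_KEY}"},
)
print(resp.status_code)  # expect 204 on success, per the handler above
```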
@@ -12,18 +12,17 @@ from services.file_service import FileService
class FileApi(WebApiResource):
    @marshal_with(file_fields)
    def post(self, app_model, end_user):
        file = request.files["file"]
        source = request.form.get("source")

        if "file" not in request.files:
            raise NoFileUploadedError()

        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError

        source = request.form.get("source")
        if source not in ("datasets", None):
            source = None
@@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[
    "is not",
    "empty",
    "not empty",
    "in",
    "not in",
    # for number
    "=",
    "≠",
@@ -23,6 +23,7 @@ from core.app.entities.task_entities import (
    MessageFileStreamResponse,
    MessageReplaceStreamResponse,
    MessageStreamResponse,
    StreamEvent,
    WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
@@ -180,11 +181,15 @@ class MessageCycleManager:
        :param message_id: message id
        :return:
        """
        message_file = db.session.query(MessageFile).filter(MessageFile.id == message_id).first()
        event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

        return MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer,
            from_variable_selector=from_variable_selector,
            event=event_type,
        )

    def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
@@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):

    scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
    required: bool = False
    default: Optional[Union[int, str]] = None
    default: Optional[Union[int, str, float, bool]] = None
    options: Optional[list[Option]] = None
    label: Optional[I18nObject] = None
    help: Optional[I18nObject] = None
@@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
        case FileAttribute.TRANSFER_METHOD:
            return file.transfer_method.value
        case FileAttribute.URL:
            return file.remote_url
            return _to_url(file)
        case FileAttribute.EXTENSION:
            return file.extension
        case FileAttribute.RELATED_ID:
@@ -121,9 +121,8 @@ class TokenBufferMemory:
        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        if curr_message_tokens > max_token_limit:
            pruned_memory = []
            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
                pruned_memory.append(prompt_messages.pop(0))
                prompt_messages.pop(0)
                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        return prompt_messages
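
The hunk above drops the unused `pruned_memory` list: the loop simply discards the oldest message until the prompt fits the token budget. A minimal standalone sketch, assuming a toy token counter in place of `self.model_instance.get_llm_num_tokens()`:

```python
# Toy stand-in for the model's token counter; the real code asks the model instance.
def count_tokens(messages: list[str]) -> int:
    return sum(len(m.split()) for m in messages)

def prune_to_limit(messages: list[str], max_token_limit: int) -> list[str]:
    curr = count_tokens(messages)
    if curr > max_token_limit:
        # Drop the oldest message (index 0) until we fit, always keeping at least one.
        while curr > max_token_limit and len(messages) > 1:
            messages.pop(0)
            curr = count_tokens(messages)
    return messages

print(prune_to_limit(["a b c", "d e", "f g h i"], 6))  # -> ['d e', 'f g h i']
```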
@@ -4,6 +4,7 @@ import logging
import os
from datetime import datetime, timedelta
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse

from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
@@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra
    try:
        # Choose the appropriate exporter based on config type
        exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]

        # Inspect the provided endpoint to determine its structure
        parsed = urlparse(arize_phoenix_config.endpoint)
        base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
        path = parsed.path.rstrip("/")

        if isinstance(arize_phoenix_config, ArizeConfig):
            arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
            arize_endpoint = f"{base_endpoint}/v1"
            arize_headers = {
                "api_key": arize_phoenix_config.api_key or "",
                "space_id": arize_phoenix_config.space_id or "",
@@ -53,7 +60,7 @@
            timeout=30,
        )
    else:
        phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
        phoenix_endpoint = f"{base_endpoint}{path}/v1/traces"
        phoenix_headers = {
            "api_key": arize_phoenix_config.api_key or "",
            "authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
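
A short demonstration of the endpoint normalization these two hunks introduce: the scheme and host are kept, any trailing slash on the path is stripped, and the traces suffix is re-appended. The self-hosted URL below is a hypothetical example.

```python
from urllib.parse import urlparse

for endpoint in [
    "https://app.phoenix.arize.com",
    "https://phoenix.example.com/my-instance/",  # hypothetical self-hosted URL
]:
    parsed = urlparse(endpoint)
    base_endpoint = f"{parsed.scheme}://{parsed.netloc}"
    path = parsed.path.rstrip("/")
    print(f"{base_endpoint}{path}/v1/traces")
# -> https://app.phoenix.arize.com/v1/traces
# -> https://phoenix.example.com/my-instance/v1/traces
```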
@@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig):
    @field_validator("endpoint")
    @classmethod
    def endpoint_validator(cls, v, info: ValidationInfo):
        return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
        return validate_url_with_path(v, "https://app.phoenix.arize.com")


class LangfuseConfig(BaseTracingConfig):
@@ -322,7 +322,7 @@ class OpsTraceManager:
        :return:
        """
        # auth check
        if enabled == True:
        if enabled:
            try:
                provider_config_map[tracing_provider]
            except KeyError:
@@ -407,7 +407,6 @@ class TraceTask:
    def __init__(
        self,
        trace_type: Any,
        trace_id: Optional[str] = None,
        message_id: Optional[str] = None,
        workflow_execution: Optional[WorkflowExecution] = None,
        conversation_id: Optional[str] = None,
@@ -423,7 +422,7 @@ class TraceTask:
        self.timer = timer
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
        self.app_id = None

        self.trace_id = None
        self.kwargs = kwargs
        external_trace_id = kwargs.get("external_trace_id")
        if external_trace_id:
@@ -208,6 +208,7 @@ class BasePluginClient:
            except Exception:
                raise PluginDaemonInnerError(code=rep.code, message=rep.message)

            logger.error("Error in stream response for plugin %s", rep.__dict__)
            self._handle_plugin_daemon_error(error.error_type, error.message)
            raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
        if rep.data is None:
@@ -2,6 +2,8 @@ from collections.abc import Mapping

from pydantic import TypeAdapter

from extensions.ext_logging import get_request_id


class PluginDaemonError(Exception):
    """Base class for all plugin daemon errors."""
@@ -11,7 +13,7 @@ class PluginDaemonError(Exception):

    def __str__(self) -> str:
        # returns the class name and description
        return f"{self.__class__.__name__}: {self.description}"
        return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"


class PluginDaemonInternalError(PluginDaemonError):
@@ -0,0 +1,190 @@
# Clickzetta Vector Database Integration

This module provides integration with Clickzetta Lakehouse as a vector database for Dify.

## Features

- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
- **Vector Search**: Efficient similarity search using the HNSW algorithm
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
- **Hybrid Search**: Combine vector similarity and full-text search for better results
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments

## Configuration

### Required Environment Variables

All seven configuration parameters are required:

```bash
# Authentication
CLICKZETTA_USERNAME=your_username
CLICKZETTA_PASSWORD=your_password

# Instance configuration
CLICKZETTA_INSTANCE=your_instance_id
CLICKZETTA_SERVICE=api.clickzetta.com
CLICKZETTA_WORKSPACE=your_workspace
CLICKZETTA_VCLUSTER=your_vcluster
CLICKZETTA_SCHEMA=your_schema
```

### Optional Configuration

```bash
# Batch processing
CLICKZETTA_BATCH_SIZE=100

# Full-text search configuration
CLICKZETTA_ENABLE_INVERTED_INDEX=true
CLICKZETTA_ANALYZER_TYPE=chinese  # Options: keyword, english, chinese, unicode
CLICKZETTA_ANALYZER_MODE=smart    # Options: max_word, smart

# Vector search configuration
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance  # Options: l2_distance, cosine_distance
```

## Usage

### 1. Set Clickzetta as the Vector Store

In your Dify configuration, set:

```bash
VECTOR_STORE=clickzetta
```

### 2. Table Structure

Clickzetta will automatically create tables with the following structure:

```sql
CREATE TABLE <collection_name> (
    id STRING NOT NULL,
    content STRING NOT NULL,
    metadata JSON,
    vector VECTOR(FLOAT, <dimension>) NOT NULL,
    PRIMARY KEY (id)
);

-- Vector index for similarity search
CREATE VECTOR INDEX idx_<collection_name>_vec
ON TABLE <schema>.<collection_name>(vector)
PROPERTIES (
    "distance.function" = "cosine_distance",
    "scalar.type" = "f32"
);

-- Inverted index for full-text search (if enabled)
CREATE INVERTED INDEX idx_<collection_name>_text
ON <schema>.<collection_name>(content)
PROPERTIES (
    "analyzer" = "chinese",
    "mode" = "smart"
);
```

## Full-Text Search Capabilities

Clickzetta supports advanced full-text search with multiple analyzers:

### Analyzer Types

1. **keyword**: No tokenization; treats the entire string as a single token
   - Best for: Exact matching, IDs, codes

2. **english**: Designed for English text
   - Features: Recognizes ASCII letters and numbers, converts to lowercase
   - Best for: English content

3. **chinese**: Chinese text tokenizer
   - Features: Recognizes Chinese and English characters, removes punctuation
   - Best for: Chinese or mixed Chinese-English content

4. **unicode**: Multi-language tokenizer based on Unicode
   - Features: Recognizes text boundaries in multiple languages
   - Best for: Multi-language content

### Analyzer Modes

- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)

### Full-Text Search Functions

- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
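As a sketch, these functions can be issued through the same `clickzetta` connector the integration uses. The connection values and table name below are placeholders; the `connect()` keyword arguments mirror the integration code later in this diff.

```python
import clickzetta  # same connector the integration uses

# Hypothetical connection values and table name, for illustration only.
conn = clickzetta.connect(
    username="your_username", password="your_password",
    instance="your_instance_id", service="api.clickzetta.com",
    workspace="your_workspace", vcluster="your_vcluster", schema="dify",
)
with conn.cursor() as cursor:
    # MATCH_ALL requires every term in the query to be present in `content`.
    cursor.execute(
        "SELECT id, content FROM dify.my_collection "
        "WHERE MATCH_ALL(content, 'vector database') LIMIT 10"
    )
    for row in cursor.fetchall():
        print(row)
```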
## Performance Optimization

### Vector Search

1. **Adjust the exploration factor** for the accuracy-vs-speed trade-off:

   ```sql
   SET cz.vector.index.search.ef=64;
   ```

2. **Use appropriate distance functions**:
   - `cosine_distance`: Best for normalized embeddings (e.g., from language models)
   - `l2_distance`: Best for raw feature vectors

### Full-Text Search

1. **Choose the right analyzer**:
   - Use `keyword` for exact matching
   - Use language-specific analyzers for better tokenization

2. **Combine with vector search**:
   - Pre-filter with full-text search for better performance
   - Use hybrid search for improved relevance

## Troubleshooting

### Connection Issues

1. Verify all 7 required configuration parameters are set
2. Check network connectivity to the Clickzetta service
3. Ensure the user has proper permissions on the schema

### Search Performance

1. Verify the vector index exists:

   ```sql
   SHOW INDEX FROM <schema>.<table_name>;
   ```

2. Check whether the vector index is being used:

   ```sql
   EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
   ```

   Look for `vector_index_search_type` in the execution plan.

### Full-Text Search Not Working

1. Verify the inverted index is created
2. Check that the analyzer configuration matches your content language
3. Use the `TOKENIZE()` function to test tokenization:

   ```sql
   SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
   ```

## Limitations

1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
2. Full-text search relevance scores are not provided by Clickzetta
3. Inverted index creation may fail for very large existing tables (the integration continues without raising an error)
4. Index naming constraints:
   - Index names must be unique within a schema
   - Only one vector index can be created per column
   - The implementation uses timestamps to ensure unique index names
5. A column can only have one vector index at a time

## References

- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
@@ -0,0 +1 @@
# Clickzetta Vector Database Integration for Dify
@@ -0,0 +1,843 @@
import json
import logging
import queue
import threading
import uuid
from typing import TYPE_CHECKING, Any, Optional

import clickzetta  # type: ignore
from pydantic import BaseModel, model_validator

if TYPE_CHECKING:
    from clickzetta import Connection

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset

logger = logging.getLogger(__name__)


# ClickZetta Lakehouse Vector Database Configuration


class ClickzettaConfig(BaseModel):
    """
    Configuration class for the Clickzetta connection.
    """

    username: str
    password: str
    instance: str
    service: str = "api.clickzetta.com"
    workspace: str = "quick_start"
    vcluster: str = "default_ap"
    schema_name: str = "dify"  # Renamed to avoid shadowing BaseModel.schema
    # Advanced settings
    batch_size: int = 20  # Reduced batch size to avoid large SQL statements
    enable_inverted_index: bool = True  # Enable inverted index for full-text search
    analyzer_type: str = "chinese"  # Analyzer type for full-text search: keyword, english, chinese, unicode
    analyzer_mode: str = "smart"  # Analyzer mode: max_word, smart
    vector_distance_function: str = "cosine_distance"  # l2_distance or cosine_distance

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """
        Validate the configuration values.
        """
        if not values.get("username"):
            raise ValueError("config CLICKZETTA_USERNAME is required")
        if not values.get("password"):
            raise ValueError("config CLICKZETTA_PASSWORD is required")
        if not values.get("instance"):
            raise ValueError("config CLICKZETTA_INSTANCE is required")
        if not values.get("service"):
            raise ValueError("config CLICKZETTA_SERVICE is required")
        if not values.get("workspace"):
            raise ValueError("config CLICKZETTA_WORKSPACE is required")
        if not values.get("vcluster"):
            raise ValueError("config CLICKZETTA_VCLUSTER is required")
        if not values.get("schema_name"):
            raise ValueError("config CLICKZETTA_SCHEMA is required")
        return values
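

# Illustrative sketch (not part of the original file): building a
# ClickzettaConfig from the environment variables documented in the README.
# The CLICKZETTA_* names come from the README; mapping CLICKZETTA_SCHEMA to
# schema_name follows the field comment above. All seven values are passed
# explicitly because the "before" validator checks the raw input dict.
def _example_config_from_env() -> "ClickzettaConfig":
    import os

    return ClickzettaConfig(
        username=os.environ["CLICKZETTA_USERNAME"],
        password=os.environ["CLICKZETTA_PASSWORD"],
        instance=os.environ["CLICKZETTA_INSTANCE"],
        service=os.environ.get("CLICKZETTA_SERVICE", "api.clickzetta.com"),
        workspace=os.environ.get("CLICKZETTA_WORKSPACE", "quick_start"),
        vcluster=os.environ.get("CLICKZETTA_VCLUSTER", "default_ap"),
        schema_name=os.environ.get("CLICKZETTA_SCHEMA", "dify"),
    )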


class ClickzettaVector(BaseVector):
    """
    Clickzetta vector storage implementation.
    """

    # Class-level write queue and lock for serializing writes
    _write_queue: Optional[queue.Queue] = None
    _write_thread: Optional[threading.Thread] = None
    _write_lock = threading.Lock()
    _shutdown = False

    def __init__(self, collection_name: str, config: ClickzettaConfig):
        super().__init__(collection_name)
        self._config = config
        self._table_name = collection_name.replace("-", "_").lower()  # Ensure a valid table name
        self._connection: Optional[Connection] = None
        self._init_connection()
        self._init_write_queue()

    def _init_connection(self):
        """Initialize the Clickzetta connection."""
        self._connection = clickzetta.connect(
            username=self._config.username,
            password=self._config.password,
            instance=self._config.instance,
            service=self._config.service,
            workspace=self._config.workspace,
            vcluster=self._config.vcluster,
            schema=self._config.schema_name,
        )

        # Set session parameters for better string handling and performance optimization
        if self._connection is not None:
            with self._connection.cursor() as cursor:
                # Use quote mode for string literal escaping to handle quotes better
                cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
                logger.info("Set string literal escape mode to 'quote' for better quote handling")

                # Performance optimization hints for vector operations
                self._set_performance_hints(cursor)

    def _set_performance_hints(self, cursor):
        """Set ClickZetta performance optimization hints for vector operations."""
        try:
            # Performance optimization hints for vector operations and query processing
            performance_hints = [
                # Vector index optimization
                "SET cz.storage.parquet.vector.index.read.memory.cache = true",
                "SET cz.storage.parquet.vector.index.read.local.cache = false",
                # Query optimization
                "SET cz.sql.table.scan.push.down.filter = true",
                "SET cz.sql.table.scan.enable.ensure.filter = true",
                "SET cz.storage.always.prefetch.internal = true",
                "SET cz.optimizer.generate.columns.always.valid = true",
                "SET cz.sql.index.prewhere.enabled = true",
                # Storage optimization
                "SET cz.storage.parquet.enable.io.prefetch = false",
                "SET cz.optimizer.enable.mv.rewrite = false",
                "SET cz.sql.dump.as.lz4 = true",
                "SET cz.optimizer.limited.optimization.naive.query = true",
                "SET cz.sql.table.scan.enable.push.down.log = false",
                "SET cz.storage.use.file.format.local.stats = false",
                "SET cz.storage.local.file.object.cache.level = all",
                # Job execution optimization
                "SET cz.sql.job.fast.mode = true",
                "SET cz.storage.parquet.non.contiguous.read = true",
                "SET cz.sql.compaction.after.commit = true",
            ]

            for hint in performance_hints:
                cursor.execute(hint)

            logger.info(
                "Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)
            )

        except Exception:
            # Catch any errors setting performance hints but continue with defaults
            logger.exception("Failed to set some performance hints, continuing with default settings")

    @classmethod
    def _init_write_queue(cls):
        """Initialize the write queue and worker thread."""
        with cls._write_lock:
            if cls._write_queue is None:
                cls._write_queue = queue.Queue()
                cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
                cls._write_thread.start()
                logger.info("Started Clickzetta write worker thread")

    @classmethod
    def _write_worker(cls):
        """Worker thread that processes write tasks sequentially."""
        while not cls._shutdown:
            try:
                # Get a task from the queue with a timeout
                if cls._write_queue is not None:
                    task = cls._write_queue.get(timeout=1)
                    if task is None:  # Shutdown signal
                        break

                    # Execute the write task
                    func, args, kwargs, result_queue = task
                    try:
                        result = func(*args, **kwargs)
                        result_queue.put((True, result))
                    except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
                        logger.exception("Write task failed")
                        result_queue.put((False, e))
                    finally:
                        cls._write_queue.task_done()
                else:
                    break
            except queue.Empty:
                continue
            except (RuntimeError, ValueError, TypeError, ConnectionError):
                logger.exception("Write worker error")

    def _execute_write(self, func, *args, **kwargs):
        """Execute a write operation through the queue."""
        if ClickzettaVector._write_queue is None:
            raise RuntimeError("Write queue not initialized")

        result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
        ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))

        # Wait for the result
        success, result = result_queue.get()
        if not success:
            raise result
        return result
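
    # Note on the write path (illustrative comment, not part of the original
    # file): every mutating operation is funneled through _execute_write,
    # which enqueues (func, args, kwargs, result_queue) for the single daemon
    # worker thread. For example, create() below effectively does
    #
    #     self._execute_write(self._create_table_and_indexes, embeddings)
    #
    # and blocks on result_queue.get() until the worker has applied the write,
    # so concurrent writers are serialized into one ordered stream and worker
    # failures are re-raised in the calling thread.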
    def get_type(self) -> str:
        """Return the vector database type."""
        return "clickzetta"

    def _ensure_connection(self) -> "Connection":
        """Ensure the connection is available and return it."""
        if self._connection is None:
            raise RuntimeError("Database connection not initialized")
        return self._connection

    def _table_exists(self) -> bool:
        """Check if the table exists."""
        try:
            connection = self._ensure_connection()
            with connection.cursor() as cursor:
                cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
                return True
        except (RuntimeError, ValueError) as e:
            if "table or view not found" in str(e).lower():
                return False
            else:
                # Re-raise if it's a different error
                raise

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection and add initial documents."""
        # Execute table creation through the write queue to avoid concurrent conflicts
        self._execute_write(self._create_table_and_indexes, embeddings)

        # Add initial texts
        if texts:
            self.add_texts(texts, embeddings, **kwargs)

    def _create_table_and_indexes(self, embeddings: list[list[float]]):
        """Create the table and indexes (executed in the write worker thread)."""
        # Check if the table already exists to avoid unnecessary index creation
        if self._table_exists():
            logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name)
            return

        # Create a table with vector and metadata columns
        dimension = len(embeddings[0]) if embeddings else 768

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
            id STRING NOT NULL COMMENT 'Unique document identifier',
            {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
            {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
            {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
                'High-dimensional embedding vector for semantic similarity search',
            PRIMARY KEY (id)
        ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
        """

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(create_table_sql)
            logger.info("Created table %s.%s", self._config.schema_name, self._table_name)

            # Create the vector index
            self._create_vector_index(cursor)

            # Create an inverted index for full-text search if enabled
            if self._config.enable_inverted_index:
                self._create_inverted_index(cursor)

    def _create_vector_index(self, cursor):
        """Create an HNSW vector index for similarity search."""
        # Use a fixed index name based on the table and column name
        index_name = f"idx_{self._table_name}_vector"

        # First check if an index already exists on this column
        try:
            cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
            existing_indexes = cursor.fetchall()
            for idx in existing_indexes:
                # Check if a vector index already exists on the embedding column
                if Field.VECTOR.value in str(idx).lower():
                    logger.info("Vector index already exists on column %s", Field.VECTOR.value)
                    return
        except (RuntimeError, ValueError) as e:
            logger.warning("Failed to check existing indexes: %s", e)

        index_sql = f"""
        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
        ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
        PROPERTIES (
            "distance.function" = "{self._config.vector_distance_function}",
            "scalar.type" = "f32",
            "m" = "16",
            "ef.construction" = "128"
        )
        """
        try:
            cursor.execute(index_sql)
            logger.info("Created vector index: %s", index_name)
        except (RuntimeError, ValueError) as e:
            error_msg = str(e).lower()
            if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg:
                logger.info("Vector index already exists: %s", e)
            else:
                logger.exception("Failed to create vector index")
                raise

    def _create_inverted_index(self, cursor):
        """Create an inverted index for full-text search."""
        # Use a fixed index name based on the table name to avoid duplicates
        index_name = f"idx_{self._table_name}_text"

        # Check if an inverted index already exists on this column
        try:
            cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
            existing_indexes = cursor.fetchall()
            for idx in existing_indexes:
                idx_str = str(idx).lower()
                # More precise check: look for an inverted index specifically on the content column
                if (
                    "inverted" in idx_str
                    and Field.CONTENT_KEY.value.lower() in idx_str
                    and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
                ):
                    logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
                    return
        except (RuntimeError, ValueError) as e:
            logger.warning("Failed to check existing indexes: %s", e)

        index_sql = f"""
        CREATE INVERTED INDEX IF NOT EXISTS {index_name}
        ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
        PROPERTIES (
            "analyzer" = "{self._config.analyzer_type}",
            "mode" = "{self._config.analyzer_mode}"
        )
        """
        try:
            cursor.execute(index_sql)
            logger.info("Created inverted index: %s", index_name)
        except (RuntimeError, ValueError) as e:
            error_msg = str(e).lower()
            # Handle ClickZetta-specific error messages
            if (
                "already exists" in error_msg
                or "already has index" in error_msg
                or "with the same type" in error_msg
                or "cannot create inverted index" in error_msg
            ) and "already has index" in error_msg:
                logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
                # Try to get the existing index name for logging
                try:
                    cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
                    existing_indexes = cursor.fetchall()
                    for idx in existing_indexes:
                        if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
                            logger.info("Found existing inverted index: %s", idx)
                            break
                except (RuntimeError, ValueError):
                    pass
            else:
                logger.warning("Failed to create inverted index: %s", e)
                # Continue without the inverted index - full-text search will fall back to LIKE

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Add documents with embeddings to the collection."""
        if not documents:
            return

        batch_size = self._config.batch_size
        total_batches = (len(documents) + batch_size - 1) // batch_size

        for i in range(0, len(documents), batch_size):
            batch_docs = documents[i : i + batch_size]
            batch_embeddings = embeddings[i : i + batch_size]

            # Execute the batch insert through the write queue
            self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)

    def _insert_batch(
        self,
        batch_docs: list[Document],
        batch_embeddings: list[list[float]],
        batch_index: int,
        batch_size: int,
        total_batches: int,
    ):
        """Insert a batch of documents using parameterized queries (executed in the write worker thread)."""
        if not batch_docs or not batch_embeddings:
            logger.warning("Empty batch provided, skipping insertion")
            return

        if len(batch_docs) != len(batch_embeddings):
            logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings))
            return

        # Prepare data for parameterized insertion
        data_rows = []
        vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768

        for doc, embedding in zip(batch_docs, batch_embeddings):
            # Optimized: minimal checks for the common case, fallback for edge cases
            metadata = doc.metadata if doc.metadata else {}

            if not isinstance(metadata, dict):
                metadata = {}

            doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))

            # Fast path for JSON serialization
            try:
                metadata_json = json.dumps(metadata, ensure_ascii=True)
            except (TypeError, ValueError):
                logger.warning("JSON serialization failed, using empty dict")
                metadata_json = "{}"

            content = doc.page_content or ""

            # According to ClickZetta docs, the vector should be formatted as an array string
            # for external systems: '[1.0, 2.0, 3.0]'
            vector_str = "[" + ",".join(map(str, embedding)) + "]"
            data_rows.append([doc_id, content, metadata_json, vector_str])

        # Check if we have any valid data to insert
        if not data_rows:
            logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches)
            return

        # Use a parameterized INSERT with executemany for better performance and security
        # Cast JSON and VECTOR in SQL, pass raw data as parameters
        columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
        insert_sql = (
            f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
            f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
        )

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            try:
                # Set session-level hints for batch insert operations
                # Note: executemany doesn't support a hints parameter, so we set them as session variables
                cursor.execute("SET cz.sql.job.fast.mode = true")
                cursor.execute("SET cz.sql.compaction.after.commit = true")
                cursor.execute("SET cz.storage.always.prefetch.internal = true")

                cursor.executemany(insert_sql, data_rows)
                logger.info(
                    "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)",
                    batch_index // batch_size + 1,
                    total_batches,
                    len(data_rows),
                    vector_dimension,
                )
            except (RuntimeError, ValueError, TypeError, ConnectionError):
                logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
                logger.exception("SQL template: %s", insert_sql)
                logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
                raise

    def text_exists(self, id: str) -> bool:
        """Check if a document exists by ID."""
        safe_id = self._safe_doc_id(id)
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(
                f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id]
            )
            result = cursor.fetchone()
            return result[0] > 0 if result else False

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete documents by IDs."""
        if not ids:
            return

        # Check if the table exists before attempting the delete
        if not self._table_exists():
            logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
            return

        # Execute the delete through the write queue
        self._execute_write(self._delete_by_ids_impl, ids)

    def _delete_by_ids_impl(self, ids: list[str]) -> None:
        """Implementation of delete by IDs (executed in the write worker thread)."""
        safe_ids = [self._safe_doc_id(id) for id in ids]
        # Create properly escaped string literals for SQL
        id_list = ",".join(f"'{id}'" for id in safe_ids)
        sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"

        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            cursor.execute(sql)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete documents by metadata field."""
        # Check if the table exists before attempting the delete
        if not self._table_exists():
            logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
            return

        # Execute the delete through the write queue
        self._execute_write(self._delete_by_metadata_field_impl, key, value)

    def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
        """Implementation of delete by metadata field (executed in the write worker thread)."""
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            # Using a JSON path to filter with a parameterized query
            # Note: the JSON path requires a literal key name and cannot be parameterized
            # Use the json_extract_string function for ClickZetta compatibility
            sql = (
                f"DELETE FROM {self._config.schema_name}.{self._table_name} "
                f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
            )
            cursor.execute(sql, [value])

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Search for documents by vector similarity."""
        top_k = kwargs.get("top_k", 10)
        score_threshold = kwargs.get("score_threshold", 0.0)
        document_ids_filter = kwargs.get("document_ids_filter")

        # Handle the filter parameter from the canvas (workflow)
        filter_param = kwargs.get("filter", {})

        # Build the filter clause
        filter_clauses = []
        if document_ids_filter:
            safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
            doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
            # Use the json_extract_string function for ClickZetta compatibility
            filter_clauses.append(
                f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
            )

        # No need for a dataset_id filter since each dataset has its own table

        # Add a distance threshold based on the distance function
        vector_dimension = len(query_vector)
        if self._config.vector_distance_function == "cosine_distance":
            # For cosine distance, smaller is better (0 = identical, 2 = opposite)
            distance_func = "COSINE_DISTANCE"
            if score_threshold > 0:
                query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
                filter_clauses.append(
                    f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
                )
        else:
            # For L2 distance, smaller is better
            distance_func = "L2_DISTANCE"
            if score_threshold > 0:
                query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
                filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")

        where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"

        # Execute the vector search query
        query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
        search_sql = f"""
        SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
               {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
        FROM {self._config.schema_name}.{self._table_name}
        WHERE {where_clause}
        ORDER BY distance
        LIMIT {top_k}
        """

        documents = []
        connection = self._ensure_connection()
        with connection.cursor() as cursor:
            # Use the hints parameter for vector search optimization
            search_hints = {
                "hints": {
                    "sdk.job.timeout": 60,  # Increase the timeout for vector search
                    "cz.sql.job.fast.mode": True,
                    "cz.storage.parquet.vector.index.read.memory.cache": True,
                }
            }
            cursor.execute(search_sql, parameters=search_hints)
            results = cursor.fetchall()

            for row in results:
                # Parse metadata from the JSON string (may be double-encoded)
                try:
                    if row[2]:
                        metadata = json.loads(row[2])

                        # If the result is a string, it's double-encoded JSON - parse again
                        if isinstance(metadata, str):
                            metadata = json.loads(metadata)

                        if not isinstance(metadata, dict):
                            metadata = {}
                    else:
                        metadata = {}
                except (json.JSONDecodeError, TypeError):
                    logger.exception("JSON parsing failed")
                    # Fallback: extract document_id with a regex
                    import re

                    doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
                    metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}

                # Ensure required fields are set
                metadata["doc_id"] = row[0]  # segment id

                # Ensure document_id exists (critical for Dify's format_retrieval_documents)
                if "document_id" not in metadata:
                    metadata["document_id"] = row[0]  # fall back to the segment id

                # Add a score based on the distance
                if self._config.vector_distance_function == "cosine_distance":
                    metadata["score"] = 1 - (row[3] / 2)
                else:
                    metadata["score"] = 1 / (1 + row[3])

                doc = Document(page_content=row[1], metadata=metadata)
                documents.append(doc)

        return documents
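
    # Illustrative usage (comment sketch, not part of the original file):
    # given a store built from the config sketch earlier, a caller would do
    #
    #     store = ClickzettaVector("my-collection", config)
    #     docs = store.search_by_vector([0.1] * 768, top_k=5, score_threshold=0.3)
    #     for doc in docs:
    #         print(doc.metadata.get("score"), doc.page_content[:80])
    #
    # where score_threshold is translated into the distance cutoff built above
    # (cosine: distance < 2 - threshold; L2: distance < threshold), and the
    # embedding length must match the table's vector dimension.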
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents using full-text search with inverted index."""
|
||||
if not self._config.enable_inverted_index:
|
||||
logger.warning("Full-text search is not enabled. Enable inverted index in config.")
|
||||
return []
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
if document_ids_filter:
|
||||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use match_all function for full-text search
|
||||
# match_all requires all terms to be present
|
||||
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
|
||||
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
# Execute full-text search query
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
"""
|
||||
|
||||
documents = []
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
# Use hints parameter for full-text search optimization
|
||||
fulltext_hints = {
|
||||
"hints": {
|
||||
"sdk.job.timeout": 30, # Timeout for full-text search
|
||||
"cz.sql.job.fast.mode": True,
|
||||
"cz.sql.index.prewhere.enabled": True,
|
||||
}
|
||||
}
|
||||
cursor.execute(search_sql, parameters=fulltext_hints)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.exception("JSON parsing failed")
|
||||
# Fallback: extract document_id with regex
|
||||
import re
|
||||
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
|
||||
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
|
||||
|
||||
# Ensure required fields are set
|
||||
metadata["doc_id"] = row[0] # segment id
|
||||
|
||||
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
|
||||
if "document_id" not in metadata:
|
||||
metadata["document_id"] = row[0] # fallback to segment id
|
||||
|
||||
# Add a relevance score for full-text search
|
||||
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
logger.exception("Full-text search failed")
|
||||
# Fallback to LIKE search if full-text search fails
|
||||
return self._search_by_like(query, **kwargs)
|
||||
|
||||
return documents
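The full-text branch above builds its WHERE clause from simple single-quote doubling plus a MATCH_ALL predicate over the content column. A minimal sketch of that clause assembly, with the column names hard-coded for illustration (the real code reads them from the Field enum):

def build_fulltext_where(query: str, document_ids: list[str] | None = None) -> str:
    # Illustrative only: mirrors the escaping and filter assembly used above.
    clauses = []
    if document_ids:
        safe_ids = ",".join("'" + str(i).replace("'", "''") + "'" for i in document_ids)
        clauses.append(f"json_extract_string(metadata, '$.document_id') IN ({safe_ids})")
    escaped_query = query.replace("'", "''")
    clauses.append(f"MATCH_ALL(page_content, '{escaped_query}')")
    return " AND ".join(clauses)

# build_fulltext_where("user's guide", ["doc-1"]) ->
#   "json_extract_string(metadata, '$.document_id') IN ('doc-1') AND MATCH_ALL(page_content, 'user''s guide')"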
|
||||
|
||||
def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Fallback search using LIKE operator."""
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
if document_ids_filter:
|
||||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
"""
|
||||
|
||||
documents = []
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
# Use hints parameter for LIKE search optimization
|
||||
like_hints = {
|
||||
"hints": {
|
||||
"sdk.job.timeout": 20, # Timeout for LIKE search
|
||||
"cz.sql.job.fast.mode": True,
|
||||
}
|
||||
}
|
||||
cursor.execute(search_sql, parameters=like_hints)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.exception("JSON parsing failed")
|
||||
# Fallback: extract document_id with regex
|
||||
import re
|
||||
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
|
||||
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
|
||||
|
||||
# Ensure required fields are set
|
||||
metadata["doc_id"] = row[0] # segment id
|
||||
|
||||
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
|
||||
if "document_id" not in metadata:
|
||||
metadata["document_id"] = row[0] # fallback to segment id
|
||||
|
||||
metadata["score"] = 0.5 # Lower score for LIKE search
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
"""Delete the entire collection."""
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")

def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
return ",".join(map(str, vector))

def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
if not doc_id:
return str(uuid.uuid4())
# Remove or replace potentially problematic characters
safe_id = str(doc_id)
# Only allow alphanumeric, hyphens, underscores
safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
if not safe_id:  # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255]  # Limit length
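For reference, the sanitizer above keeps only alphanumerics, hyphens, and underscores, truncates to 255 characters, and falls back to a fresh UUID when nothing survives. A short standalone mirror of those rules (illustrative only):

import uuid

def safe_doc_id(doc_id: str) -> str:
    # Standalone mirror of _safe_doc_id above.
    if not doc_id:
        return str(uuid.uuid4())
    cleaned = "".join(c for c in str(doc_id) if c.isalnum() or c in "-_")
    return cleaned[:255] if cleaned else str(uuid.uuid4())

print(safe_doc_id("seg_42/alpha"))  # "seg_42alpha" (the slash is dropped)
print(len(safe_doc_id("")))         # 36, a freshly generated UUID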
|
||||
|
||||
|
||||
class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances."""

def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
"""Initialize a Clickzetta vector instance."""
# Get configuration from environment variables or dataset config
config = ClickzettaConfig(
username=dify_config.CLICKZETTA_USERNAME or "",
password=dify_config.CLICKZETTA_PASSWORD or "",
instance=dify_config.CLICKZETTA_INSTANCE or "",
service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True,
analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",
analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart",
vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance",
)

# Use dataset collection name as table name
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()

return ClickzettaVector(collection_name=collection_name, config=config)
|
||||
|
|
@ -7,6 +7,7 @@ from urllib.parse import urlparse
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator

from core.rag.datasource.vdb.field import Field

@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
return cast(str, info["version"]["number"])

def _check_version(self):
if self._version < "8.0.0":
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")

def get_type(self) -> str:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import tablestore # type: ignore
|
||||
|
|
@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
|
|||
access_key_secret: Optional[str] = None
|
||||
instance_name: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
normalize_full_text_bm25_score: Optional[bool] = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
|
|||
config.access_key_secret,
|
||||
config.instance_name,
|
||||
)
|
||||
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||
self._table_name = f"{collection_name}"
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||
|
|
@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
|
|||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
|
||||
return self._search_by_full_text(query, filtered_list, top_k)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._delete_table_if_exist()
|
||||
|
|
@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents

def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
@staticmethod
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
"""
Args:
score: BM25 search score.
k: decay factor, the larger the k, the steeper the low score end
"""
normalized_score = 1 - math.exp(-k * score)
return max(0.0, min(1.0, normalized_score))
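The normalization above squashes unbounded BM25 scores into [0, 1] with an exponential saturation curve; a quick numeric check with the default k = 0.15 (standalone sketch):

import math

def normalize_bm25(score: float, k: float = 0.15) -> float:
    # 1 - e^(-k * score): 0 stays 0, large scores approach 1.
    return max(0.0, min(1.0, 1 - math.exp(-k * score)))

print(round(normalize_bm25(0.0), 3))   # 0.0
print(round(normalize_bm25(5.0), 3))   # 0.528
print(round(normalize_bm25(20.0), 3))  # 0.95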

def _search_by_full_text(
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||
|
||||
|
|
@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
score = None
|
||||
if self._normalize_full_text_bm25_score:
|
||||
score = self._normalize_score_exp_decay(search_hit.score)
|
||||
|
||||
# When normalization is enabled, skip results whose normalized score falls below the threshold
|
||||
if score and score <= score_threshold:
|
||||
continue
|
||||
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
if score:
|
||||
metadata["score"] = score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
|
|
@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
|
|||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
if self._normalize_full_text_bm25_score:
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
|
||||
|
|
@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
|
|||
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
||||
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
||||
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
||||
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -246,6 +246,10 @@ class TencentVector(BaseVector):
|
|||
return self._get_search_res(res, score_threshold)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
filter = None
|
||||
if document_ids_filter:
|
||||
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||
if not self._enable_hybrid_search:
|
||||
return []
|
||||
res = self._client.hybrid_search(
|
||||
|
|
@ -269,6 +273,7 @@ class TencentVector(BaseVector):
|
|||
),
|
||||
retrieve_vector=False,
|
||||
limit=kwargs.get("top_k", 4),
|
||||
filter=filter,
|
||||
)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._get_search_res(res, score_threshold)
|
||||
|
|
|
|||
|
|
@ -172,6 +172,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
|
||||
|
||||
return MatrixoneVectorFactory
|
||||
case VectorType.CLICKZETTA:
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
|
|
|||
|
|
@ -30,3 +30,4 @@ class VectorType(StrEnum):
|
|||
TABLESTORE = "tablestore"
|
||||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[
|
|||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
|
|
@ -130,13 +131,15 @@ class NotionExtractor(BaseExtractor):
|
|||
data[property_name] = value
|
||||
row_dict = {k: v for k, v in data.items() if v}
|
||||
row_content = ""
|
||||
for key, value in row_dict.items():
|
||||
for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)):
|
||||
if isinstance(value, dict):
|
||||
value_dict = {k: v for k, v in value.items() if v}
|
||||
value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
|
||||
row_content = row_content + f"{key}:{value_content}\n"
|
||||
else:
|
||||
row_content = row_content + f"{key}:{value}\n"
|
||||
if "url" in result:
|
||||
row_content = row_content + f"Row Page URL:{result.get('url', '')}\n"
|
||||
database_content.append(row_content)
|
||||
|
||||
has_more = response_data.get("has_more", False)
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor):
|
|||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
content = self.parse_docx(self.file_path, "storage")
|
||||
content = self.parse_docx(self.file_path)
|
||||
return [
|
||||
Document(
|
||||
page_content=content,
|
||||
|
|
@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor):
|
|||
paragraph_content.append(run.text)
|
||||
return "".join(paragraph_content).strip()
|
||||
|
||||
def _parse_paragraph(self, paragraph, image_map):
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if embed_id:
|
||||
rel_target = run.part.rels[embed_id].target_ref
|
||||
if rel_target in image_map:
|
||||
paragraph_content.append(image_map[rel_target])
|
||||
if run.text.strip():
|
||||
paragraph_content.append(run.text.strip())
|
||||
return " ".join(paragraph_content) if paragraph_content else ""
|
||||
|
||||
def parse_docx(self, docx_path, image_folder):
|
||||
def parse_docx(self, docx_path):
|
||||
doc = DocxDocument(docx_path)
|
||||
os.makedirs(image_folder, exist_ok=True)
|
||||
|
||||
content = []
|
||||
|
||||
|
|
|
|||
|
|
@ -5,14 +5,13 @@ from __future__ import annotations
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
from core.rag.splitter.text_splitter import (
|
||||
TS,
|
||||
Collection,
|
||||
Literal,
|
||||
RecursiveCharacterTextSplitter,
|
||||
Set,
|
||||
TokenTextSplitter,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
|
@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
|||
|
||||
return [len(text) for text in texts]
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_character_encoder, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,6 @@ class Tool(ABC):
|
|||
The base class of a tool
|
||||
"""
|
||||
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
|
|
|||
|
|
@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
|||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
try:
|
||||
if local_tz is None:
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
if isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
if local_tz is None:
|
||||
localtime = local_time.astimezone() # type: ignore
|
||||
elif isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool):
|
|||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
|
||||
if not target_time:
|
||||
yield self.create_text_message(
|
||||
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ class BuiltinTool(Tool):
|
|||
:param meta: the meta data of a tool call processing
|
||||
"""
|
||||
|
||||
provider: str
|
||||
|
||||
def __init__(self, provider: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
import json
from collections.abc import Generator
from dataclasses import dataclass
from os import getenv
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.parse import urlencode

import httpx

@ -20,10 +21,21 @@ API_TOOL_DEFAULT_TIMEOUT = (
)


class ApiTool(Tool):
api_bundle: ApiToolBundle
provider_id: str
@dataclass
class ParsedResponse:
"""Represents a parsed HTTP response with type information"""

content: Union[str, dict]
is_json: bool

def to_string(self) -> str:
"""Convert response to string format for credential validation"""
if isinstance(self.content, dict):
return json.dumps(self.content, ensure_ascii=False)
return str(self.content)
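A brief usage sketch of the ParsedResponse introduced above (the payloads are illustrative):

# JSON body served with a JSON content type: kept as a dict for downstream JSON messages.
ok = ParsedResponse(content={"status": "ok"}, is_json=True)
print(ok.to_string())                    # {"status": "ok"}

# Plain-text body: passed through unchanged and flagged as non-JSON.
plain = ParsedResponse(content="pong", is_json=False)
print(plain.is_json, plain.to_string())  # False pong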


class ApiTool(Tool):
"""
Api tool
"""
|
||||
|
|
@ -61,7 +73,9 @@ class ApiTool(Tool):
|
|||
|
||||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
|
||||
# validate response
|
||||
return self.validate_and_parse_response(response)
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
# For credential validation, always return as string
|
||||
return parsed_response.to_string()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
|
@ -115,23 +129,36 @@ class ApiTool(Tool):
|
|||
|
||||
return headers
|
||||
|
||||
def validate_and_parse_response(self, response: httpx.Response) -> str:
|
||||
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
|
||||
"""
|
||||
validate the response
|
||||
validate the response and return parsed content with type information
|
||||
|
||||
:return: ParsedResponse with content and is_json flag
|
||||
"""
|
||||
if isinstance(response, httpx.Response):
|
||||
if response.status_code >= 400:
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
if not response.content:
|
||||
return "Empty response from the tool, please check your parameters and try again."
|
||||
return ParsedResponse(
|
||||
"Empty response from the tool, please check your parameters and try again.", False
|
||||
)
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
is_json_content_type = "application/json" in content_type
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
response = response.json()
|
||||
try:
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
except Exception:
|
||||
return json.dumps(response)
|
||||
json_response = response.json()
|
||||
# If content-type indicates JSON, return as JSON object
|
||||
if is_json_content_type:
|
||||
return ParsedResponse(json_response, True)
|
||||
else:
|
||||
# If content-type doesn't indicate JSON, treat as text regardless of content
|
||||
return ParsedResponse(response.text, False)
|
||||
except Exception:
|
||||
return response.text
|
||||
# Not valid JSON, return as text
|
||||
return ParsedResponse(response.text, False)
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
|
||||
|
|
@ -372,7 +399,14 @@ class ApiTool(Tool):
|
|||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
|
||||
|
||||
# validate response
|
||||
response = self.validate_and_parse_response(response)
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message
|
||||
yield self.create_text_message(response)
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json and isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
text_response = (
|
||||
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
|
||||
)
|
||||
yield self.create_text_message(text_response)
|
||||
|
|
|
|||
|
|
@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient
|
|||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
|
|
|
|||
|
|
@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
|
|
@ -21,7 +16,7 @@ class PluginTool(Tool):
|
|||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
self.runtime_parameters: Optional[list[ToolParameter]] = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from core.tools.errors import (
|
|||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
|
|
@ -247,7 +247,8 @@ class ToolEngine:
|
|||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
result += json.dumps(
|
||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
else:
|
||||
result += str(response.message)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from os import listdir, path
|
|||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from yarl import URL
|
||||
|
||||
|
|
@ -616,7 +617,7 @@ class ToolManager:
|
|||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
|||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
|
@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)


def safe_json_value(v):
if isinstance(v, datetime):
tz_name = getattr(current_user, "timezone", None) if current_user is not None else None
if not tz_name:
tz_name = "UTC"
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v


def safe_json_dict(d):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
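A short usage sketch of safe_json_value / safe_json_dict, showing the conversions applied before json.dumps (values are illustrative; the datetime branch also consults the current user's timezone, so it is left out here because it expects a Flask request context):

import json
from datetime import date
from decimal import Decimal
from uuid import uuid4

payload = {
    "id": uuid4(),               # UUID    -> str
    "amount": Decimal("19.90"),  # Decimal -> float
    "when": date(2024, 1, 1),    # date    -> "2024-01-01"
    "blob": b"\xff\xfe",         # non-UTF-8 bytes -> hex string
}
print(json.dumps(safe_json_dict(payload), ensure_ascii=False))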
|
||||
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(
|
||||
|
|
@ -113,6 +155,12 @@ class ToolFileMessageTransformer:
|
|||
)
|
||||
else:
|
||||
yield message
|
||||
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
|
||||
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
|
||||
json_msg.json_object = safe_json_value(json_msg.json_object)
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
thread_pool_id: Optional[str] = None
|
||||
workflow_as_tool_id: str
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -119,6 +119,13 @@ class ObjectSegment(Segment):
|
|||
|
||||
|
||||
class ArraySegment(Segment):
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# Return empty string for empty arrays instead of "[]"
|
||||
if not self.value:
|
||||
return ""
|
||||
return super().text
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
items = []
|
||||
|
|
@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment):
|
|||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# Return empty string for empty arrays instead of "[]"
|
||||
if not self.value:
|
||||
return ""
|
||||
return json.dumps(self.value, ensure_ascii=False)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class SegmentType(StrEnum):
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
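The one-character-looking change above fixes a real correctness bug: wrapping each per-element check in a list made all() test the truthiness of one-element lists (always true) rather than the check results themselves. A minimal demonstration:

checks = [False, False]
# Old pattern: every generator item is a non-empty list, so all() is always True.
print(all([c] for c in checks))  # True  -- the bug
# Fixed pattern: all() sees the boolean results directly.
print(all(c for c in checks))    # False -- correct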
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
"""
|
||||
|
|
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
|
|||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have correpond element type.
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
SegmentType.ARRAY_STRING: SegmentType.STRING,
|
||||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
|
|
|
|||
|
|
@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
|
|||
|
||||
for i in range(1, len(raw_results)):
|
||||
spk, txt = raw_results[i]
|
||||
if spk == None:
|
||||
if spk is None:
|
||||
merged_results.append((None, current_text))
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class Executor:
|
|||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
self.ssl_verify = node_data.ssl_verify
|
||||
self.params = []
|
||||
self.params = None
|
||||
self.headers = {}
|
||||
self.content = None
|
||||
self.files = None
|
||||
|
|
@ -139,7 +139,8 @@ class Executor:
|
|||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
|
||||
)
|
||||
|
||||
self.params = result
|
||||
if result:
|
||||
self.params = result
|
||||
|
||||
def _init_headers(self):
|
||||
"""
|
||||
|
|
@ -277,6 +278,22 @@ class Executor:
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""

# Handle Content-Type for multipart/form-data requests
# Fix for issue #22880: Missing boundary when using multipart/form-data
body = self.node_data.body
if body and body.type == "form-data":
# For multipart/form-data with files, let httpx handle the boundary automatically
# by not setting Content-Type header when files are present
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
# Only set Content-Type when there are no actual files
# This ensures httpx generates the correct boundary
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = "multipart/form-data"
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
# Set Content-Type for other body types
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]

return headers
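The fix above relies on httpx's own multipart handling: when files are passed and no Content-Type header is forced, httpx generates the boundary itself. A hedged sketch of the two request shapes (the URL and field names are placeholders):

import httpx

# With real files: omit Content-Type and let httpx emit
# "multipart/form-data; boundary=..." automatically.
files = {"upload": ("report.txt", b"hello", "text/plain")}
with_files = httpx.Request("POST", "https://example.com/upload", files=files)
print(with_files.headers["content-type"])  # multipart/form-data; boundary=<generated>

# Without files: plain form fields get an explicit, boundary-free content type.
form_only = httpx.Request("POST", "https://example.com/upload", data={"a": "1"})
print(form_only.headers["content-type"])   # application/x-www-form-urlencoded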
|
||||
|
||||
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
|
||||
|
|
@ -384,15 +401,24 @@ class Executor:
|
|||
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
|
||||
# This prevents logging meaningless placeholder entries.
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for key, (filename, content, mime_type) in self.files:
|
||||
for file_entry in self.files:
|
||||
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
|
||||
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
|
||||
continue # skip malformed entries
|
||||
key = file_entry[0]
|
||||
content = file_entry[1][1]
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
# decode content
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# fix: decode binary content
|
||||
pass
|
||||
# decode content safely
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
elif isinstance(content, str):
|
||||
body_string += content
|
||||
else:
|
||||
body_string += f"[Unsupported content type: {type(content).__name__}]"
|
||||
body_string += "\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[
|
|||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
|
|
|
|||
|
|
@ -602,6 +602,28 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
**{key: metadata_name, key_value: f"%{value}"}
|
||||
)
|
||||
)
|
||||
case "in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "not in":
|
||||
if isinstance(value, str):
|
||||
escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")]
|
||||
escaped_value_str = ",".join(escaped_values)
|
||||
else:
|
||||
escaped_value_str = str(value)
|
||||
filters.append(
|
||||
(text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params(
|
||||
**{key: metadata_name, key_value: escaped_value_str}
|
||||
)
|
||||
)
|
||||
case "=" | "is":
|
||||
if isinstance(value, str):
|
||||
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
|
|
@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import (
|
|||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
|
|
@ -1006,21 +1004,6 @@ class LLMNode(BaseNode):
|
|||
)
|
||||
return saved_file
|
||||
|
||||
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
|
||||
"""
|
||||
Fetch model schema
|
||||
"""
|
||||
model_name = self._node_data.model.name
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
|
||||
)
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
model_credentials = model_instance.credentials
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_schema
|
||||
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -318,6 +318,33 @@ class ToolNode(BaseNode):
|
|||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": message.message.text,
|
||||
}
|
||||
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
|
|
|
|||
|
|
@ -136,6 +136,8 @@ def init_app(app: DifyApp):
|
|||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
|
|
@ -234,6 +236,8 @@ def init_app(app: DifyApp):
|
|||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
|
||||
instrument_exception_logging()
|
||||
init_sqlalchemy_instrumentor(app)
|
||||
RedisInstrumentor().instrument()
|
||||
RequestsInstrumentor().instrument()
|
||||
atexit.register(shutdown_tracer)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -69,6 +69,19 @@ class Storage:
|
|||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
case StorageType.CLICKZETTA_VOLUME:
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
def create_clickzetta_volume_storage():
|
||||
# ClickZettaVolumeConfig will automatically read from environment variables
|
||||
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
|
||||
volume_config = ClickZettaVolumeConfig()
|
||||
return ClickZettaVolumeStorage(volume_config)
|
||||
|
||||
return create_clickzetta_volume_storage
|
||||
case _:
|
||||
raise ValueError(f"unsupported storage type {storage_type}")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
|
|
@ -0,0 +1,530 @@
|
|||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import clickzetta # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: Optional[str] = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
import os
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# Temporarily disable the permission check and default it to False
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
"""Get the appropriate volume path based on volume type."""
if self._config.volume_type == "user":
# Add dify prefix for User Volume to organize files
return f"{self._config.dify_prefix}/{filename}"
elif self._config.volume_type == "table":
# Check if this should use User Volume (special directories)
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
# Use User Volume with dify prefix for special directories
return f"{self._config.dify_prefix}/{filename}"

if dataset_id:
return f"{self._config.table_prefix}{dataset_id}/{filename}"
else:
# Extract dataset_id from filename if not provided
# Format: dataset_id/filename
if "/" in filename:
return filename
else:
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
elif self._config.volume_type == "external":
return filename
else:
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
    def load_once(self, filename: str) -> bytes:
        """Load file content from ClickZetta Volume.

        Args:
            filename: File path in volume

        Returns:
            File content as bytes
        """
        # Extract dataset_id from filename if present
        dataset_id = None
        if "/" in filename and self._config.volume_type == "table":
            parts = filename.split("/", 1)
            if parts[0].startswith(self._config.table_prefix):
                dataset_id = parts[0][len(self._config.table_prefix) :]
                filename = parts[1]
            else:
                dataset_id = parts[0]
                filename = parts[1]

        # Check permissions (if enabled)
        if self._config.permission_check:
            # Skip permission check for special directories that use USER VOLUME
            if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
                if self._permission_manager is not None:
                    check_volume_permission(self._permission_manager, "load_once", dataset_id)

        # Download to temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # Get the actual volume path (may include dify_km prefix)
            volume_path = self._get_volume_path(filename, dataset_id)

            # For User Volume, use the full path with dify_km prefix
            if volume_prefix == "USER VOLUME":
                sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
            else:
                sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"

            self._execute_sql(sql)

            # Find the downloaded file (may be in subdirectories)
            downloaded_file = None
            for root, dirs, files in os.walk(temp_dir):
                for file in files:
                    if file == filename or file == os.path.basename(filename):
                        downloaded_file = Path(root) / file
                        break
                if downloaded_file:
                    break

            if not downloaded_file or not downloaded_file.exists():
                raise FileNotFoundError(f"Downloaded file not found: {filename}")

            content = downloaded_file.read_bytes()

            logger.debug("File %s loaded from ClickZetta Volume", filename)
            return content

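    # Illustrative note (not part of the original diff): depending on the volume prefix,
    # save() and load_once() above build SQL of one of these two shapes, where the
    # placeholders in angle brackets stand for the runtime values used in the f-strings:
    #
    #   PUT '<temp_file_path>' TO USER VOLUME FILE '<volume_path>'
    #   PUT '<temp_file_path>' TO <volume_prefix> FILE '<filename>'
    #   GET USER VOLUME FILE '<volume_path>' TO '<temp_dir>'
    #   GET <volume_prefix> FILE '<filename>' TO '<temp_dir>'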
    def load_stream(self, filename: str) -> Generator:
        """Load file as stream from ClickZetta Volume.

        Args:
            filename: File path in volume

        Yields:
            File content chunks
        """
        content = self.load_once(filename)
        batch_size = 4096
        stream = BytesIO(content)

        while chunk := stream.read(batch_size):
            yield chunk

        logger.debug("File %s loaded as stream from ClickZetta Volume", filename)

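    # Illustrative usage sketch (not from the original diff): load_stream() yields 4 KiB
    # chunks, so callers can forward large files without holding a second full copy, e.g.
    # (the target path and `storage` instance are hypothetical):
    #
    #   with open("/tmp/copy.bin", "wb") as out:
    #       for chunk in storage.load_stream("dataset_123/report.pdf"):
    #           out.write(chunk)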
    def download(self, filename: str, target_filepath: str):
        """Download file from ClickZetta Volume to local path.

        Args:
            filename: File path in volume
            target_filepath: Local target file path
        """
        content = self.load_once(filename)

        with Path(target_filepath).open("wb") as f:
            f.write(content)

        logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)

    def exists(self, filename: str) -> bool:
        """Check if file exists in ClickZetta Volume.

        Args:
            filename: File path in volume

        Returns:
            True if file exists, False otherwise
        """
        try:
            # Extract dataset_id from filename if present
            dataset_id = None
            if "/" in filename and self._config.volume_type == "table":
                parts = filename.split("/", 1)
                if parts[0].startswith(self._config.table_prefix):
                    dataset_id = parts[0][len(self._config.table_prefix) :]
                    filename = parts[1]
                else:
                    dataset_id = parts[0]
                    filename = parts[1]

            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # Get the actual volume path (may include dify_km prefix)
            volume_path = self._get_volume_path(filename, dataset_id)

            # For User Volume, use the full path with dify_km prefix
            if volume_prefix == "USER VOLUME":
                sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
            else:
                sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"

            rows = self._execute_sql(sql, fetch=True)

            exists = len(rows) > 0
            logger.debug("File %s exists check: %s", filename, exists)
            return exists
        except Exception as e:
            logger.warning("Error checking file existence for %s: %s", filename, e)
            return False

    def delete(self, filename: str):
        """Delete file from ClickZetta Volume.

        Args:
            filename: File path in volume
        """
        if not self.exists(filename):
            logger.debug("File %s not found, skip delete", filename)
            return

        # Extract dataset_id from filename if present
        dataset_id = None
        if "/" in filename and self._config.volume_type == "table":
            parts = filename.split("/", 1)
            if parts[0].startswith(self._config.table_prefix):
                dataset_id = parts[0][len(self._config.table_prefix) :]
                filename = parts[1]
            else:
                dataset_id = parts[0]
                filename = parts[1]

        volume_prefix = self._get_volume_sql_prefix(dataset_id)

        # Get the actual volume path (may include dify_km prefix)
        volume_path = self._get_volume_path(filename, dataset_id)

        # For User Volume, use the full path with dify_km prefix
        if volume_prefix == "USER VOLUME":
            sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
        else:
            sql = f"REMOVE {volume_prefix} FILE '{filename}'"

        self._execute_sql(sql)

        logger.debug("File %s deleted from ClickZetta Volume", filename)

    def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
        """Scan files and directories in ClickZetta Volume.

        Args:
            path: Path to scan (dataset_id for table volumes)
            files: Include files in results
            directories: Include directories in results

        Returns:
            List of file/directory paths
        """
        try:
            # For table volumes, path is treated as dataset_id
            dataset_id = None
            if self._config.volume_type == "table":
                dataset_id = path
                path = ""  # Root of the table volume

            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # For User Volume, add dify prefix to path
            if volume_prefix == "USER VOLUME":
                if path:
                    scan_path = f"{self._config.dify_prefix}/{path}"
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
                else:
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
            else:
                if path:
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
                else:
                    sql = f"LIST {volume_prefix}"

            rows = self._execute_sql(sql, fetch=True)

            result = []
            for row in rows:
                file_path = row[0]  # relative_path column

                # For User Volume, remove dify prefix from results
                dify_prefix_with_slash = f"{self._config.dify_prefix}/"
                if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
                    file_path = file_path[len(dify_prefix_with_slash) :]  # Remove prefix

                if (files and not file_path.endswith("/")) or (directories and file_path.endswith("/")):
                    result.append(file_path)

            logger.debug("Scanned %d items in path %s", len(result), path)
            return result

        except Exception as e:
            logger.exception("Error scanning path %s", path)
            return []

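A minimal usage sketch of the volume storage interface above (not part of the diff): it assumes an already-configured storage instance, since the constructor and configuration live in the earlier, suppressed part of this file. The `run_smoke_check` helper and the key `dataset_123/report.pdf` are hypothetical.

def run_smoke_check(storage) -> None:
    """Exercise the ClickZetta volume storage methods shown above (illustrative only)."""
    key = "dataset_123/report.pdf"  # hypothetical table-volume style key

    storage.save(key, b"hello clickzetta")           # PUT into the volume
    assert storage.exists(key)                       # LIST ... REGEXP check

    print(storage.load_once(key))                    # full download via GET
    streamed = b"".join(storage.load_stream(key))    # 4 KiB chunked read
    assert streamed == b"hello clickzetta"

    print(storage.scan("dataset_123", files=True))   # list files in the dataset volume
    storage.delete(key)                              # REMOVE from the volume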
@ -0,0 +1,516 @@
"""ClickZetta Volume文件生命周期管理
|
||||
|
||||
该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。
|
||||
支持知识库文件的完整生命周期管理。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(Enum):
    """File status enum."""

    ACTIVE = "active"  # Active
    ARCHIVED = "archived"  # Archived
    DELETED = "deleted"  # Deleted (soft delete)
    BACKUP = "backup"  # Backup file


@dataclass
class FileMetadata:
    """File metadata."""

    filename: str
    size: int | None
    created_at: datetime
    modified_at: datetime
    version: int | None
    status: FileStatus
    checksum: Optional[str] = None
    tags: Optional[dict[str, str]] = None
    parent_version: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to a dictionary."""
        data = asdict(self)
        data["created_at"] = self.created_at.isoformat()
        data["modified_at"] = self.modified_at.isoformat()
        data["status"] = self.status.value
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "FileMetadata":
        """Create an instance from a dictionary."""
        data = data.copy()
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        data["modified_at"] = datetime.fromisoformat(data["modified_at"])
        data["status"] = FileStatus(data["status"])
        return cls(**data)


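# Illustrative round trip (not part of the original file): to_dict()/from_dict() keep the
# metadata JSON-serializable and lossless, e.g. (values hypothetical):
#
#   meta = FileMetadata(filename="doc.txt", size=12, created_at=datetime.now(),
#                       modified_at=datetime.now(), version=1, status=FileStatus.ACTIVE)
#   assert FileMetadata.from_dict(meta.to_dict()) == meta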
class FileLifecycleManager:
    """File lifecycle manager."""

    def __init__(self, storage, dataset_id: Optional[str] = None):
        """Initialize the lifecycle manager.

        Args:
            storage: ClickZetta Volume storage instance
            dataset_id: Dataset ID (used for Table Volume)
        """
        self._storage = storage
        self._dataset_id = dataset_id
        self._metadata_file = ".dify_file_metadata.json"
        self._version_prefix = ".versions/"
        self._backup_prefix = ".backups/"
        self._deleted_prefix = ".deleted/"

        # Get the permission manager (if one exists)
        self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)

    def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
        """Save a file and manage its lifecycle.

        Args:
            filename: File name
            data: File content
            tags: File tags

        Returns:
            File metadata
        """
        # Permission check
        if not self._check_permission(filename, "save"):
            from .volume_permissions import VolumePermissionError

            raise VolumePermissionError(
                f"Permission denied for lifecycle save operation on file: {filename}",
                operation="save",
                volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
                dataset_id=self._dataset_id,
            )

        try:
            # 1. Check whether an earlier version exists
            metadata_dict = self._load_metadata()
            current_metadata = metadata_dict.get(filename)

            # 2. If an earlier version exists, create a version backup
            if current_metadata:
                self._create_version_backup(filename, current_metadata)

            # 3. Compute file information
            now = datetime.now()
            checksum = self._calculate_checksum(data)
            new_version = (current_metadata["version"] + 1) if current_metadata else 1

            # 4. Save the new file
            self._storage.save(filename, data)

            # 5. Create metadata
            created_at = now
            parent_version = None

            if current_metadata:
                # If created_at is a string, convert it to datetime
                if isinstance(current_metadata["created_at"], str):
                    created_at = datetime.fromisoformat(current_metadata["created_at"])
                else:
                    created_at = current_metadata["created_at"]
                parent_version = current_metadata["version"]

            file_metadata = FileMetadata(
                filename=filename,
                size=len(data),
                created_at=created_at,
                modified_at=now,
                version=new_version,
                status=FileStatus.ACTIVE,
                checksum=checksum,
                tags=tags or {},
                parent_version=parent_version,
            )

            # 6. Update metadata
            metadata_dict[filename] = file_metadata.to_dict()
            self._save_metadata(metadata_dict)

            logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
            return file_metadata

        except Exception as e:
            logger.exception("Failed to save file with lifecycle")
            raise

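    # Illustrative behaviour sketch (not part of the original file): repeated saves bump the
    # version and snapshot the previous content under ".versions/", e.g. (names hypothetical):
    #
    #   manager = FileLifecycleManager(storage, dataset_id="dataset_123")
    #   v1 = manager.save_with_lifecycle("doc.txt", b"first draft")   # version == 1
    #   v2 = manager.save_with_lifecycle("doc.txt", b"second draft")  # version == 2,
    #                                                                 # ".versions/doc.txt.v1" written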
    def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
        """Get file metadata.

        Args:
            filename: File name

        Returns:
            File metadata, or None if it does not exist
        """
        try:
            metadata_dict = self._load_metadata()
            if filename in metadata_dict:
                return FileMetadata.from_dict(metadata_dict[filename])
            return None
        except Exception as e:
            logger.exception("Failed to get file metadata for %s", filename)
            return None

    def list_file_versions(self, filename: str) -> list[FileMetadata]:
        """List all versions of a file.

        Args:
            filename: File name

        Returns:
            List of file versions, sorted by version number
        """
        try:
            versions = []

            # Get the current version
            current_metadata = self.get_file_metadata(filename)
            if current_metadata:
                versions.append(current_metadata)

            # Get historical versions
            version_pattern = f"{self._version_prefix}{filename}.v*"
            try:
                version_files = self._storage.scan(self._dataset_id or "", files=True)
                for file_path in version_files:
                    if file_path.startswith(f"{self._version_prefix}{filename}.v"):
                        # Parse the version number
                        version_str = file_path.split(".v")[-1].split(".")[0]
                        try:
                            version_num = int(version_str)
                            # Simplified handling: metadata should really be read from the
                            # version file; for now only basic metadata is produced.
                        except ValueError:
                            continue
            except:
                # If version files cannot be scanned, return only the current version
                pass

            return sorted(versions, key=lambda x: x.version or 0, reverse=True)

        except Exception as e:
            logger.exception("Failed to list file versions for %s", filename)
            return []

    def restore_version(self, filename: str, version: int) -> bool:
        """Restore a file to the specified version.

        Args:
            filename: File name
            version: Version number to restore

        Returns:
            Whether the restore succeeded
        """
        try:
            version_filename = f"{self._version_prefix}{filename}.v{version}"

            # Check whether the version file exists
            if not self._storage.exists(version_filename):
                logger.warning("Version %s of %s not found", version, filename)
                return False

            # Read the version file content
            version_data = self._storage.load_once(version_filename)

            # Back up the current version
            current_metadata = self.get_file_metadata(filename)
            if current_metadata:
                self._create_version_backup(filename, current_metadata.to_dict())

            # Restore the file
            self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
            return True

        except Exception as e:
            logger.exception("Failed to restore %s to version %s", filename, version)
            return False

    def archive_file(self, filename: str) -> bool:
        """Archive a file.

        Args:
            filename: File name

        Returns:
            Whether archiving succeeded
        """
        # Permission check
        if not self._check_permission(filename, "archive"):
            logger.warning("Permission denied for archive operation on file: %s", filename)
            return False

        try:
            # Mark the file as archived
            metadata_dict = self._load_metadata()
            if filename not in metadata_dict:
                logger.warning("File %s not found in metadata", filename)
                return False

            metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
            metadata_dict[filename]["modified_at"] = datetime.now().isoformat()

            self._save_metadata(metadata_dict)

            logger.info("File %s archived successfully", filename)
            return True

        except Exception as e:
            logger.exception("Failed to archive file %s", filename)
            return False

    def soft_delete_file(self, filename: str) -> bool:
        """Soft delete a file (move it to the deleted directory).

        Args:
            filename: File name

        Returns:
            Whether deletion succeeded
        """
        # Permission check
        if not self._check_permission(filename, "delete"):
            logger.warning("Permission denied for soft delete operation on file: %s", filename)
            return False

        try:
            # Check whether the file exists
            if not self._storage.exists(filename):
                logger.warning("File %s not found", filename)
                return False

            # Read the file content
            file_data = self._storage.load_once(filename)

            # Move it to the deleted directory
            deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self._storage.save(deleted_filename, file_data)

            # Delete the original file
            self._storage.delete(filename)

            # Update metadata
            metadata_dict = self._load_metadata()
            if filename in metadata_dict:
                metadata_dict[filename]["status"] = FileStatus.DELETED.value
                metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
                self._save_metadata(metadata_dict)

            logger.info("File %s soft deleted successfully", filename)
            return True

        except Exception as e:
            logger.exception("Failed to soft delete file %s", filename)
            return False

    def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
        """Clean up old version files.

        Args:
            max_versions: Maximum number of versions to keep
            max_age_days: Maximum age of version files in days

        Returns:
            Number of files cleaned up
        """
        try:
            cleaned_count = 0
            cutoff_date = datetime.now() - timedelta(days=max_age_days)

            # Get all version files
            try:
                all_files = self._storage.scan(self._dataset_id or "", files=True)
                version_files = [f for f in all_files if f.startswith(self._version_prefix)]

                # Group by file
                file_versions: dict[str, list[tuple[int, str]]] = {}
                for version_file in version_files:
                    # Parse the file name and version
                    parts = version_file[len(self._version_prefix) :].split(".v")
                    if len(parts) >= 2:
                        base_filename = parts[0]
                        version_part = parts[1].split(".")[0]
                        try:
                            version_num = int(version_part)
                            if base_filename not in file_versions:
                                file_versions[base_filename] = []
                            file_versions[base_filename].append((version_num, version_file))
                        except ValueError:
                            continue

                # Clean up old versions of each file
                for base_filename, versions in file_versions.items():
                    # Sort by version number
                    versions.sort(key=lambda x: x[0], reverse=True)

                    # Keep the newest max_versions versions and delete the rest
                    if len(versions) > max_versions:
                        to_delete = versions[max_versions:]
                        for version_num, version_file in to_delete:
                            self._storage.delete(version_file)
                            cleaned_count += 1
                            logger.debug("Cleaned old version: %s", version_file)

                logger.info("Cleaned %d old version files", cleaned_count)

            except Exception as e:
                logger.warning("Could not scan for version files: %s", e)

            return cleaned_count

        except Exception as e:
            logger.exception("Failed to cleanup old versions")
            return 0

    def get_storage_statistics(self) -> dict[str, Any]:
        """Get storage statistics.

        Returns:
            Storage statistics dictionary
        """
        try:
            metadata_dict = self._load_metadata()

            stats: dict[str, Any] = {
                "total_files": len(metadata_dict),
                "active_files": 0,
                "archived_files": 0,
                "deleted_files": 0,
                "total_size": 0,
                "versions_count": 0,
                "oldest_file": None,
                "newest_file": None,
            }

            oldest_date = None
            newest_date = None

            for filename, metadata in metadata_dict.items():
                file_meta = FileMetadata.from_dict(metadata)

                # Count file statuses
                if file_meta.status == FileStatus.ACTIVE:
                    stats["active_files"] = (stats["active_files"] or 0) + 1
                elif file_meta.status == FileStatus.ARCHIVED:
                    stats["archived_files"] = (stats["archived_files"] or 0) + 1
                elif file_meta.status == FileStatus.DELETED:
                    stats["deleted_files"] = (stats["deleted_files"] or 0) + 1

                # Accumulate size
                stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)

                # Accumulate version counts
                stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)

                # Track the oldest and newest files
                if oldest_date is None or file_meta.created_at < oldest_date:
                    oldest_date = file_meta.created_at
                    stats["oldest_file"] = filename

                if newest_date is None or file_meta.modified_at > newest_date:
                    newest_date = file_meta.modified_at
                    stats["newest_file"] = filename

            return stats

        except Exception as e:
            logger.exception("Failed to get storage statistics")
            return {}

    def _create_version_backup(self, filename: str, metadata: dict):
        """Create a version backup."""
        try:
            # Read the current file content
            current_data = self._storage.load_once(filename)

            # Save it as a version file
            version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
            self._storage.save(version_filename, current_data)

            logger.debug("Created version backup: %s", version_filename)

        except Exception as e:
            logger.warning("Failed to create version backup for %s: %s", filename, e)

    def _load_metadata(self) -> dict[str, Any]:
        """Load the metadata file."""
        try:
            if self._storage.exists(self._metadata_file):
                metadata_content = self._storage.load_once(self._metadata_file)
                result = json.loads(metadata_content.decode("utf-8"))
                return dict(result) if result else {}
            else:
                return {}
        except Exception as e:
            logger.warning("Failed to load metadata: %s", e)
            return {}

    def _save_metadata(self, metadata_dict: dict):
        """Save the metadata file."""
        try:
            metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
            self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
            logger.debug("Metadata saved successfully")
        except Exception as e:
            logger.exception("Failed to save metadata")
            raise

    def _calculate_checksum(self, data: bytes) -> str:
        """Calculate the file checksum."""
        import hashlib

        return hashlib.md5(data).hexdigest()

    def _check_permission(self, filename: str, operation: str) -> bool:
        """Check permission for a file operation.

        Args:
            filename: File name
            operation: Operation type

        Returns:
            True if permission granted, False otherwise
        """
        # If there is no permission manager, allow by default
        if not self._permission_manager:
            return True

        try:
            # Map the operation type to a permission
            operation_mapping = {
                "save": "save",
                "load": "load_once",
                "delete": "delete",
                "archive": "delete",  # Archiving requires delete permission
                "restore": "save",  # Restoring requires write permission
                "cleanup": "delete",  # Cleanup requires delete permission
                "read": "load_once",
                "write": "save",
            }

            mapped_operation = operation_mapping.get(operation, operation)

            # Check the permission
            result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
            return bool(result)

        except Exception as e:
            logger.exception("Permission check failed for %s operation %s", filename, operation)
            # Secure default: deny access if the permission check fails
            return False
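A self-contained usage sketch of FileLifecycleManager (not part of the diff). The in-memory DictStorage stub, the dataset ID, and the import path are hypothetical; the stub only has to provide the save/load_once/exists/delete/scan methods the manager calls, and the real ClickZetta volume storage would normally be passed instead.

# Hypothetical import path; adjust to wherever file_lifecycle.py lives in the repo.
from extensions.storage.clickzetta_volume.file_lifecycle import FileLifecycleManager


class DictStorage:
    """Minimal in-memory stand-in for the ClickZetta volume storage (illustrative only)."""

    def __init__(self):
        self._files: dict[str, bytes] = {}

    def save(self, filename: str, data: bytes) -> None:
        self._files[filename] = data

    def load_once(self, filename: str) -> bytes:
        return self._files[filename]

    def exists(self, filename: str) -> bool:
        return filename in self._files

    def delete(self, filename: str) -> None:
        self._files.pop(filename, None)

    def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
        return list(self._files)


if __name__ == "__main__":
    manager = FileLifecycleManager(DictStorage(), dataset_id="dataset_123")
    manager.save_with_lifecycle("doc.txt", b"first draft")
    manager.save_with_lifecycle("doc.txt", b"second draft", tags={"source": "example"})

    print(manager.get_file_metadata("doc.txt"))    # version 2, status ACTIVE
    print(manager.list_file_versions("doc.txt"))   # current version (history parsing is simplified)
    print(manager.get_storage_statistics())        # counts, total size, oldest/newest file
    manager.soft_delete_file("doc.txt")            # moved under ".deleted/", status DELETED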
Some files were not shown because too many files have changed in this diff.