merge main

zxhlyh 2025-08-05 10:30:53 +08:00
commit 201e4cd64d
308 changed files with 10716 additions and 1994 deletions

View File

@ -99,3 +99,6 @@ jobs:
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh

View File

@ -9,6 +9,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

View File

@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba Cloud with [Alibaba Cloud Computing Nest](https:/
One-click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Deploy to AKS with Azure Devops Pipeline
One-click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -217,6 +217,10 @@ docker compose up -d
Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -235,6 +235,10 @@ Star Dify on GitHub
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -233,6 +233,9 @@ docker compose up -d
One-click deploy Dify to Alibaba Cloud using [Alibaba Cloud Data Management (DMS)](https://help.aliyun.com/zh/dms/dify-in-invitational-preview)
#### Using Azure Devops Pipeline to Deploy to AKS
One-click deploy Dify to AKS using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Star History

View File

@ -230,6 +230,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
One-click deployment of Dify on Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline for AKS Deployment
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -230,6 +230,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
Deploy Dify to Alibaba Cloud with a single click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -228,6 +228,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
Deploy Dify to Alibaba Cloud in one click with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS in one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -227,6 +227,10 @@ docker compose up -d
#### Alibaba Cloud Data Management
Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -228,6 +228,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -222,6 +222,10 @@ Deploy Dify to Kubernetes and configure premium scaling
Deploy Dify to Alibaba Cloud with one click via [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -227,6 +227,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -228,6 +228,10 @@ Deploy Dify to AWS using [CDK](https://aws.amazon.com/cdk/)
Deploy Dify to Alibaba Cloud with one click with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -221,6 +221,10 @@ One-click deploy Dify to a cloud platform with [terraform](https://www.ter
Deploy Dify to Alibaba Cloud with one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -233,6 +233,10 @@ All of Dify's features come with corresponding APIs, so you can easily integrate Dify
One-click deploy Dify to Alibaba Cloud via [Alibaba Cloud Data Management (DMS)](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
One-click deploy Dify to AKS using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -224,6 +224,10 @@ Deploy Dify on AWS using [CDK](https://aws.amazon.com/cdk/)
Deploy Dify to Alibaba Cloud with just one click using [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
#### Using Azure Devops Pipeline to Deploy to AKS
Deploy Dify to AKS with just one click using [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
## Contributing

View File

@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
TABLESTORE_INSTANCE_NAME=instance-name
TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com

View File

@ -5,6 +5,7 @@ import secrets
from typing import Any, Optional
import click
import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
@ -462,7 +463,7 @@ def convert_to_agent_apps():
"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query))
rs = conn.execute(sa.text(sql_query))
apps = []
for i in rs:
@ -707,7 +708,7 @@ def fix_app_site_missing():
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
where sites.id is null limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
rs = conn.execute(sa.text(sql))
processed_count = 0
for i in rs:
@ -921,7 +922,7 @@ def clear_orphaned_file_records(force: bool):
)
orphaned_message_files = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
@ -942,7 +943,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
@ -959,7 +960,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
@ -979,7 +980,7 @@ def clear_orphaned_file_records(force: bool):
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
@ -994,7 +995,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1013,7 +1014,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1042,7 +1043,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
except Exception as e:
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
return
@ -1112,7 +1113,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append(str(i[0]))
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))

View File

@ -215,7 +215,7 @@ class DatabaseConfig(BaseSettings):
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
default="redis",
)
@ -245,7 +245,12 @@ class CeleryConfig(DatabaseConfig):
@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
if self.CELERY_BACKEND in ("database", "rabbitmq"):
return f"db+{self.SQLALCHEMY_DATABASE_URI}"
elif self.CELERY_BACKEND == "redis":
return self.CELERY_BROKER_URL
else:
return None
@property
def BROKER_USE_SSL(self) -> bool:
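For clarity, a minimal sketch of what the new branching resolves to for each CELERY_BACKEND value; the URIs below are illustrative placeholders, not values from this diff:

    # Hypothetical restatement of the CELERY_RESULT_BACKEND logic above; URIs are placeholders.
    def resolve_result_backend(celery_backend: str, sqlalchemy_uri: str, broker_url: str) -> str | None:
        if celery_backend in ("database", "rabbitmq"):
            # with a RabbitMQ broker, task results are routed to the database instead
            return f"db+{sqlalchemy_uri}"
        elif celery_backend == "redis":
            return broker_url
        return None

    assert resolve_result_backend("rabbitmq", "postgresql://dify", "amqp://guest@rabbitmq") == "db+postgresql://dify"
    assert resolve_result_backend("redis", "postgresql://dify", "redis://redis:6379/1") == "redis://redis:6379/1"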

View File

@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
description="AccessKey secret for the instance name",
default=None,
)
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
description="Whether to normalize full-text search scores to [0, 1]",
default=False,
)

View File

@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])

View File

@ -84,6 +84,7 @@ from .datasets import (
external,
hit_testing,
metadata,
upload_file,
website,
)
from .datasets.rag_pipeline import (

View File

@ -67,7 +67,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "message_count": i.message_count})
@ -176,7 +176,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -234,7 +234,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
@ -310,7 +310,7 @@ ORDER BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
@ -373,7 +373,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -435,7 +435,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
@ -495,7 +495,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})

View File

@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -71,7 +72,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
@ -133,7 +134,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -195,7 +196,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -277,7 +278,7 @@ GROUP BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}

View File

@ -645,7 +645,7 @@ class DocumentIndexingStatusApi(DocumentResource):
return marshal(document_dict, document_status_fields)
class DocumentDetailApi(DocumentResource):
class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"}
@setup_required
@ -733,6 +733,28 @@ class DocumentDetailApi(DocumentResource):
return response, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentProcessingApi(DocumentResource):
@setup_required
@ -771,30 +793,6 @@ class DocumentProcessingApi(DocumentResource):
return {"result": "success"}, 200
class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@ -1075,11 +1073,10 @@ api.add_resource(
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")

View File

@ -0,0 +1,62 @@
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(Resource):
@setup_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
.first()
)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
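For reference, a sketch of the 200 response body this new endpoint produces, reconstructed from the return statement above; all values are placeholders:

    # Hypothetical response shape from UploadFileApi.get; every value is a placeholder.
    example_response = {
        "id": "<upload_file_id>",
        "name": "report.pdf",
        "size": 102400,                                    # bytes
        "extension": "pdf",
        "url": "<signed_file_url>",
        "download_url": "<signed_file_url>&as_attachment=true",
        "mime_type": "application/pdf",
        "created_by": "<account_id>",
        "created_at": 1722825000.0,                        # POSIX timestamp from created_at.timestamp()
    }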

View File

@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
code = 429
class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit"
class ComplianceRateLimitError(BaseHTTPException):
error_code = "compliance_rate_limit"
description = "Rate limit exceeded for downloading compliance report."
code = 429

View File

@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.service_api import api
@ -30,6 +30,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -113,7 +114,7 @@ class ChatApi(Resource):
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
parser.add_argument("workflow_id", type=str, required=False, location="json")
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
@ -128,6 +129,12 @@ class ChatApi(Resource):
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
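A hedged usage sketch for the new optional workflow_id parameter, assuming the standard service-API chat route; host, token, and IDs are placeholders. Per the handlers above, an unknown workflow maps to 404, while a draft or malformed workflow ID maps to 400:

    import httpx

    # Hypothetical request pinning a chat run to a specific workflow (placeholders throughout).
    resp = httpx.post(
        "https://api.dify.ai/v1/chat-messages",
        headers={"Authorization": "Bearer <app_api_key>"},
        json={
            "query": "Hello",
            "inputs": {},
            "user": "end-user-1",
            "response_mode": "blocking",
            "workflow_id": "<workflow_id>",  # new optional field added in this diff
        },
    )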

View File

@ -1,7 +1,9 @@
import json
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.service_api import api
@ -15,6 +17,7 @@ from fields.conversation_fields import (
simple_conversation_fields,
)
from fields.conversation_variable_fields import (
conversation_variable_fields,
conversation_variable_infinite_scroll_pagination_fields,
)
from libs.helper import uuid_value
@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
raise NotFound("Conversation Not Exists.")
class ConversationVariableDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(conversation_variable_fields)
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
conversation_id = str(c_id)
variable_id = str(variable_id)
parser = reqparse.RequestParser()
parser.add_argument("value", required=True, location="json")
args = parser.parse_args()
try:
return ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:
raise NotFound("Conversation Variable Not Exists.")
except services.errors.conversation.ConversationVariableTypeMismatchError as e:
raise BadRequest(str(e))
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
api.add_resource(
ConversationVariableDetailApi,
"/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
endpoint="conversation_variable_detail",
methods=["PUT"],
)
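A hedged usage sketch for the new PUT route registered above; host, token, and IDs are placeholders. Note that "value" is sent as a JSON-encoded string because the handler calls json.loads() on it before updating the variable:

    import json

    import httpx

    # Hypothetical update of a conversation variable (placeholders throughout).
    httpx.put(
        "https://api.dify.ai/v1/conversations/<conversation_id>/variables/<variable_id>",
        headers={"Authorization": "Bearer <app_api_key>"},
        json={"user": "end-user-1", "value": json.dumps(["a", "b"])},
    )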

View File

@ -5,7 +5,7 @@ from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api import api
from controllers.service_api.app.error import (
@ -34,6 +34,7 @@ from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
raise InternalServerError()
class WorkflowRunByIdApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
"""
Run a specific workflow by ID
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
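A hedged usage sketch for the new run-by-ID route, assuming the standard service-API prefix; host, token, and IDs are placeholders:

    import httpx

    # Hypothetical run of a specific workflow version (placeholders throughout).
    resp = httpx.post(
        "https://api.dify.ai/v1/workflows/<workflow_id>/run",
        headers={"Authorization": "Bearer <app_api_key>"},
        json={
            "inputs": {},                 # required dict
            "response_mode": "blocking",  # or "streaming"
            "user": "end-user-1",         # end-user identity, fetched from the JSON body
        },
    )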

View File

@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
class DocumentDetailApi(DatasetApiResource):
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
def get(self, tenant_id, dataset_id, document_id):
@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):
return response
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
api.add_resource(
DocumentAddByTextApi,
@ -588,7 +586,6 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
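A hedged sketch of deleting a document through the consolidated route; per the handler above, success returns 204 with no body. Host, IDs, and token are placeholders:

    import httpx

    # Hypothetical document deletion via the service API (placeholders throughout).
    resp = httpx.delete(
        "https://api.dify.ai/v1/datasets/<dataset_id>/documents/<document_id>",
        headers={"Authorization": "Bearer <dataset_api_key>"},
    )
    assert resp.status_code == 204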

View File

@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
default: Optional[Union[int, str, float, bool]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None

View File

@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return file.remote_url
return _to_url(file)
case FileAttribute.EXTENSION:
return file.extension
case FileAttribute.RELATED_ID:

View File

@ -322,7 +322,7 @@ class OpsTraceManager:
:return:
"""
# auth check
if enabled == True:
if enabled:
try:
provider_config_map[tracing_provider]
except KeyError:

View File

@ -7,6 +7,7 @@ from urllib.parse import urlparse
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.field import Field
@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
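The switch to packaging matters because plain string comparison is lexicographic, so a two-digit major version would be rejected by the old check even though it satisfies the requirement. A minimal illustration:

    from packaging.version import parse as parse_version

    assert "10.1.0" < "8.0.0"                                # string comparison: "1" < "8"
    assert parse_version("10.1.0") > parse_version("8.0.0")  # semantic comparison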
def get_type(self) -> str:

View File

@ -1,5 +1,6 @@
import json
import logging
import math
from typing import Any, Optional
import tablestore # type: ignore
@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
access_key_secret: Optional[str] = None
instance_name: Optional[str] = None
endpoint: Optional[str] = None
normalize_full_text_bm25_score: Optional[bool] = False
@model_validator(mode="before")
@classmethod
@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
config.access_key_secret,
config.instance_name,
)
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
self._table_name = f"{collection_name}"
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]
return self._search_by_full_text(query, filtered_list, top_k)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
def delete(self) -> None:
self._delete_table_if_exist()
@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
@staticmethod
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
"""
Args:
score: BM25 search score.
k: decay factor; the larger k is, the steeper the curve at the low-score end
"""
normalized_score = 1 - math.exp(-k * score)
return max(0.0, min(1.0, normalized_score))
def _search_by_full_text(
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
documents = []
for search_hit in search_response.search_hits:
score = None
if self._normalize_full_text_bm25_score:
score = self._normalize_score_exp_decay(search_hit.score)
# skip when normalization is enabled and the normalized score is at or below the threshold
if score and score <= score_threshold:
continue
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR.value)
vector = json.loads(vector_str) if vector_str else None
if score:
metadata["score"] = score
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
metadata=metadata,
)
)
if self._normalize_full_text_bm25_score:
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
),
)
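A small worked example of the exponential-decay normalization added above, using its default k = 0.15; the input scores are illustrative:

    import math

    def normalize(score: float, k: float = 0.15) -> float:
        # mirrors TableStoreVector._normalize_score_exp_decay: maps BM25 scores into [0, 1)
        return max(0.0, min(1.0, 1 - math.exp(-k * score)))

    # illustrative BM25 scores -> normalized values (rounded to 3 decimals):
    # 1.0 -> 0.139, 5.0 -> 0.528, 10.0 -> 0.777, 30.0 -> 0.989
    for s in (1.0, 5.0, 10.0, 30.0):
        print(s, round(normalize(s), 3))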

View File

@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool):
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time:
yield self.create_text_message(
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
)
return

View File

@ -7,6 +7,7 @@ from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from yarl import URL
@ -616,7 +617,7 @@ class ToolManager:
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod

View File

@ -109,7 +109,7 @@ class SegmentType(StrEnum):
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
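The one-character fix above is subtle: the old code passed a single-element list per item to all(), and any non-empty list is truthy, so the check always returned True. A quick illustration:

    assert all([False]) is False             # all() over the values themselves
    assert all([[False]]) is True            # a non-empty list is truthy: the old bug
    assert all(x for x in [False]) is False  # generator form checks each value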
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
# ARRAY_ANY does not have corresponding element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,

View File

@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
for i in range(1, len(raw_results)):
spk, txt = raw_results[i]
if spk == None:
if spk is None:
merged_results.append((None, current_text))
continue

View File

@ -277,6 +277,22 @@ class Executor:
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
# Handle Content-Type for multipart/form-data requests
# Fix for issue #22880: Missing boundary when using multipart/form-data
body = self.node_data.body
if body and body.type == "form-data":
# For multipart/form-data with files, let httpx handle the boundary automatically
# by not setting Content-Type header when files are present
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
# Only set Content-Type when there are no actual files
# This ensures httpx generates the correct boundary
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = "multipart/form-data"
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
# Set Content-Type for other body types
if "content-type" not in (k.lower() for k in headers):
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
return headers
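A hedged illustration of the boundary behavior the comments above rely on: when no Content-Type is forced, httpx builds the multipart header, including the boundary, at request-construction time. The URL below is a placeholder:

    import httpx

    # httpx generates "multipart/form-data; boundary=..." when files are supplied;
    # forcing a bare multipart header manually would drop the boundary.
    req = httpx.Request(
        "POST",
        "https://example.com/upload",  # placeholder URL
        files={"file": ("a.txt", b"hello", "text/plain")},
    )
    assert req.headers["Content-Type"].startswith("multipart/form-data; boundary=")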
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
@ -384,15 +400,24 @@ class Executor:
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
# This prevents logging meaningless placeholder entries.
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
for key, (filename, content, mime_type) in self.files:
for file_entry in self.files:
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
continue # skip malformed entries
key = file_entry[0]
content = file_entry[1][1]
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
# fix: decode binary content
pass
# decode content safely
if isinstance(content, bytes):
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
elif isinstance(content, str):
body_string += content
else:
body_string += f"[Unsupported content type: {type(content).__name__}]"
body_string += "\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:

View File

@ -3,7 +3,7 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelFeature,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -1006,21 +1004,6 @@ class LLMNode(BaseNode):
)
return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
@staticmethod
def fetch_structured_output_schema(
*,

View File

@ -318,6 +318,33 @@ class ToolNode(BaseNode):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": message.message.text,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])

View File

@ -1,4 +1,6 @@
import mimetypes
import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
@ -241,16 +243,21 @@ def _build_from_remote_url(
def _get_remote_file_info(url: str):
file_size = -1
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
mime_type = mimetypes.guess_type(filename)[0] or ""
parsed_url = urllib.parse.urlparse(url)
url_path = parsed_url.path
filename = os.path.basename(url_path)
# Initialize mime_type from filename as fallback
mime_type, _ = mimetypes.guess_type(filename)
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))
# Re-guess mime_type from updated filename
mime_type, _ = mimetypes.guess_type(filename)
file_size = int(resp.headers.get("Content-Length", file_size))
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
return mime_type, filename, file_size
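A short sketch of why the rewrite parses the URL before taking the basename: query strings and fragments no longer leak into the filename used for MIME guessing. The URL is illustrative:

    import mimetypes
    import os
    import urllib.parse

    url = "https://example.com/files/report.pdf?sig=abc#top"
    path = urllib.parse.urlparse(url).path   # "/files/report.pdf"
    filename = os.path.basename(path)        # "report.pdf"
    mime_type, _ = mimetypes.guess_type(filename)
    assert (filename, mime_type) == ("report.pdf", "application/pdf")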

View File

@ -59,6 +59,8 @@ model_config_fields = {
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_detail_fields = {
"id": fields.String,
"name": fields.String,
@ -77,6 +79,7 @@ app_detail_fields = {
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
}
prompt_config_fields = {
@ -92,8 +95,6 @@ model_config_partial_fields = {
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_partial_fields = {
"id": fields.String,
"name": fields.String,
@ -185,7 +186,6 @@ app_detail_fields_with_site = {
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
@ -195,6 +195,8 @@ app_detail_fields_with_site = {
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"site": fields.Nested(site_fields),
}

View File

@ -0,0 +1,25 @@
"""manual dataset field update
Revision ID: 532b3f888abf
Revises: 8bcc02c9bd07
Create Date: 2025-07-24 14:50:48.779833
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '532b3f888abf'
down_revision = '8bcc02c9bd07'
branch_labels = None
depends_on = None
def upgrade():
op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
def downgrade():
op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")

View File

@ -3,8 +3,9 @@ import json
from datetime import datetime
from typing import Optional, cast
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore
from sqlalchemy import func, select
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@ -83,26 +84,24 @@ class AccountStatus(enum.StrEnum):
class Account(UserMixin, Base):
__tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(db.String(255))
email: Mapped[str] = mapped_column(db.String(255))
password: Mapped[Optional[str]] = mapped_column(db.String(255))
password_salt: Mapped[Optional[str]] = mapped_column(db.String(255))
avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
interface_language: Mapped[Optional[str]] = mapped_column(db.String(255))
interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
timezone: Mapped[Optional[str]] = mapped_column(db.String(255))
last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(
db.DateTime, server_default=func.current_timestamp(), nullable=False
)
status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying"))
initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[Optional[str]] = mapped_column(String(255))
password_salt: Mapped[Optional[str]] = mapped_column(String(255))
avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
interface_language: Mapped[Optional[str]] = mapped_column(String(255))
interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
timezone: Mapped[Optional[str]] = mapped_column(String(255))
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
@reconstructor
def init_on_load(self):
@ -197,16 +196,16 @@ class TenantStatus(enum.StrEnum):
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(db.String(255))
encrypt_public_key = db.Column(db.Text)
plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key = db.Column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
@ -227,56 +226,56 @@ class Tenant(Base):
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
db.Index("tenant_account_join_account_id_idx", "account_id"),
db.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
sa.Index("tenant_account_join_account_id_idx", "account_id"),
sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
role: Mapped[str] = mapped_column(db.String(16), server_default="normal")
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
db.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(db.String(16))
open_id: Mapped[str] = mapped_column(db.String(255))
encrypted_token: Mapped[str] = mapped_column(db.String(255))
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
encrypted_token: Mapped[str] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
db.Index("invitation_codes_batch_idx", "batch"),
db.Index("invitation_codes_code_idx", "code", "status"),
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
sa.Index("invitation_codes_batch_idx", "batch"),
sa.Index("invitation_codes_code_idx", "code", "status"),
)
id: Mapped[int] = mapped_column(db.Integer)
batch: Mapped[str] = mapped_column(db.String(255))
code: Mapped[str] = mapped_column(db.String(32))
status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying"))
used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
id: Mapped[int] = mapped_column(sa.Integer)
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
class TenantPluginPermission(Base):
@ -292,16 +291,14 @@ class TenantPluginPermission(Base):
__tablename__ = "account_plugin_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
class TenantPluginAutoUpgradeStrategy(Base):
@ -317,20 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base):
__tablename__ = "tenant_plugin_auto_upgrade_strategies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
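
The pattern repeated in these hunks is a migration from Flask-SQLAlchemy's `db.*` shorthand to SQLAlchemy 2.0 typed declarative mappings: `import sqlalchemy as sa` plus direct `String`/`DateTime` imports, with `Mapped[...]` annotations carrying the Python-side type. A minimal, self-contained sketch of the target style — the `Demo` model, its columns, and the local `Base` are invented for illustration and are not part of Dify (the real `Base` lives in models/base.py):

from datetime import datetime
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):  # stand-in for the project's Base
    pass


class Demo(Base):  # hypothetical model, for illustration only
    __tablename__ = "demos"

    # Mapped[...] tells the ORM and type checkers the Python-side type;
    # the SQL type, constraints, and server defaults stay in mapped_column().
    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String(255))
    note: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)  # Optional maps to nullable
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())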

View File

@ -1,10 +1,11 @@
import enum
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
@ -18,13 +19,13 @@ class APIBasedExtensionPoint(enum.Enum):
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
db.Index("api_based_extension_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
api_key = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
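
The `server_default=sa.text("uuid_generate_v4()")` defaults used throughout these models assume PostgreSQL with the uuid-ossp extension enabled, so rows inserted without an explicit id get their UUID generated server-side. A rough sketch of the prerequisite — the connection URL is a placeholder, and in practice the project's migrations handle this step:

import sqlalchemy as sa

# Placeholder URL; point at a real PostgreSQL instance to try it.
engine = sa.create_engine("postgresql+psycopg2://user:pass@localhost/example")

with engine.begin() as conn:
    # uuid_generate_v4() is provided by the uuid-ossp extension.
    conn.execute(sa.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'))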

View File

@ -12,7 +12,8 @@ from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
@ -38,25 +39,25 @@ class DatasetPermissionEnum(enum.StrEnum):
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
db.Index("dataset_tenant_idx", "tenant_id"),
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
sa.Index("dataset_tenant_idx", "tenant_id"),
sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
)
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description = mapped_column(sa.Text, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
data_source_type = mapped_column(String(255))
indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
index_struct = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column
@ -294,16 +295,16 @@ class Dataset(Base):
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
rules = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
@ -334,72 +335,70 @@ class DatasetProcessRule(Base):
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
db.Index("document_dataset_id_idx", "dataset_id"),
db.Index("document_is_paused_idx", "is_paused"),
db.Index("document_tenant_idx", "tenant_id"),
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="document_pkey"),
sa.Index("document_dataset_id_idx", "dataset_id"),
sa.Index("document_is_paused_idx", "is_paused"),
sa.Index("document_tenant_idx", "tenant_id"),
sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_info = mapped_column(sa.Text, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_api_request_id = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
# start processing
processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# parsing
file_id = mapped_column(sa.Text, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# cleaning
cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# split
splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# indexing
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# pause
is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
paused_by = mapped_column(StringUUID, nullable=True)
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# error
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
archived_reason = mapped_column(String(255), nullable=True)
archived_by = mapped_column(StringUUID, nullable=True)
archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(JSONB, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
doc_language = mapped_column(String(255), nullable=True)
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@ -556,7 +555,7 @@ class Document(Base):
"id": "built-in",
"name": BuiltInField.upload_date,
"type": "time",
"value": self.created_at.timestamp(),
"value": str(self.created_at.timestamp()),
}
)
built_in_fields.append(
@ -564,7 +563,7 @@ class Document(Base):
"id": "built-in",
"name": BuiltInField.last_update_date,
"type": "time",
"value": self.updated_at.timestamp(),
"value": str(self.updated_at.timestamp()),
}
)
built_in_fields.append(
@ -677,45 +676,45 @@ class Document(Base):
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
db.Index("document_segment_dataset_id_idx", "dataset_id"),
db.Index("document_segment_document_id_idx", "document_id"),
db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
db.Index("document_segment_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
sa.Index("document_segment_dataset_id_idx", "dataset_id"),
sa.Index("document_segment_document_id_idx", "document_id"),
sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
sa.Index("document_segment_tenant_idx", "tenant_id"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int]
content = mapped_column(sa.Text, nullable=False)
answer = mapped_column(sa.Text, nullable=True)
word_count: Mapped[int]
tokens: Mapped[int]
# indexing fields
keywords = mapped_column(sa.JSON, nullable=True)
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
# basic fields
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
@property
def dataset(self):
@ -828,32 +827,36 @@ class DocumentSegment(Base):
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
db.Index("child_chunks_segment_idx", "segment_id"),
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
sa.Index("child_chunks_segment_idx", "segment_id"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(sa.Text, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
error = mapped_column(sa.Text, nullable=True)
@property
def dataset(self):
@ -871,14 +874,14 @@ class ChildChunk(Base):
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
)
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def app(self):
@ -888,32 +891,32 @@ class AppDatasetJoin(Base):
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
db.Index("dataset_query_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
sa.Index("dataset_query_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False)
content = mapped_column(sa.Text, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source_app_id = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table = mapped_column(sa.Text, nullable=False)
data_source_type = mapped_column(
db.String(255), nullable=False, server_default=db.text("'database'::character varying")
String(255), nullable=False, server_default=sa.text("'database'::character varying")
)
@property
@ -950,19 +953,19 @@ class DatasetKeywordTable(Base):
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
db.Index("created_at_idx", "created_at"),
sa.PrimaryKeyConstraint("id", name="embedding_pkey"),
sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
sa.Index("created_at_idx", "created_at"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
model_name = mapped_column(
db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying")
)
hash = mapped_column(String(64), nullable=False)
embedding = mapped_column(sa.LargeBinary, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying"))
def set_embedding(self, embedding_data: list[float]):
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
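
`set_embedding` pickles the vector into the `LargeBinary` column; reading it back is simply the inverse `pickle.loads` (the class's getter is truncated from this hunk). A quick round-trip sketch on a detached instance, with illustrative field values:

import pickle

emb = Embedding(model_name="text-embedding-ada-002", hash="abc123", provider_name="openai")
emb.set_embedding([0.1, 0.2, 0.3])
# Inverse of set_embedding; the model's own getter (not shown in this hunk)
# is assumed to wrap the same pickle.loads call.
assert pickle.loads(emb.embedding) == [0.1, 0.2, 0.3]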
@ -974,84 +977,84 @@ class Embedding(Base):
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
db.Index("provider_model_name_idx", "provider_name", "model_name"),
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False)
collection_name = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
db.Index("tidb_auth_bindings_active_idx", "active"),
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
db.Index("tidb_auth_bindings_status_idx", "status"),
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
sa.Index("tidb_auth_bindings_active_idx", "active"),
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=sa.text("'CREATING'::character varying"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
db.Index("whitelists_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
db.Index("idx_dataset_permissions_account_id", "account_id"),
db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
sa.Index("idx_dataset_permissions_account_id", "account_id"),
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True)
dataset_id = mapped_column(StringUUID, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
db.Index("external_knowledge_apis_name_idx", "name"),
sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
sa.Index("external_knowledge_apis_name_idx", "name"),
)
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
settings = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
return {
@ -1091,71 +1094,79 @@ class ExternalKnowledgeApis(Base):
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
external_knowledge_id = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
db.Index("rate_limit_log_tenant_idx", "tenant_id"),
db.Index("rate_limit_log_operation_idx", "operation"),
sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
sa.Index("rate_limit_log_tenant_idx", "tenant_id"),
sa.Index("rate_limit_log_operation_idx", "operation"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
sa.Index("dataset_metadata_tenant_idx", "tenant_id"),
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@ -1163,19 +1174,19 @@ class DatasetMetadata(Base):
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
db.Index("dataset_metadata_binding_document_idx", "document_id"),
sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
metadata_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_by = mapped_column(StringUUID, nullable=False)
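
With `select` now imported directly from sqlalchemy alongside `func`, reads against these models follow the 2.0 query style. A hedged illustration of a typical lookup — `session`, `tenant_id`, and `dataset_id` are assumed to exist and this exact query is not part of the commit:

from sqlalchemy import select

# `session` is an active SQLAlchemy Session; the two ids are known UUID strings.
stmt = (
    select(DatasetMetadataBinding)
    .where(
        DatasetMetadataBinding.tenant_id == tenant_id,
        DatasetMetadataBinding.dataset_id == dataset_id,
    )
    .order_by(DatasetMetadataBinding.created_at.desc())
)
bindings = session.scalars(stmt).all()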

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@ -35,10 +35,10 @@ from .types import StringUUID
class DifySetup(Base):
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
version: Mapped[str] = mapped_column(String(255), nullable=False)
setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMode(StrEnum):
@ -71,33 +71,33 @@ class IconType(Enum):
class App(Base):
__tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji
icon = db.Column(String(255))
icon_background: Mapped[Optional[str]] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
workflow_id = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
enable_site: Mapped[bool] = mapped_column(sa.Boolean)
enable_api: Mapped[bool] = mapped_column(sa.Boolean)
api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
tracing = mapped_column(sa.Text, nullable=True)
max_active_requests: Mapped[Optional[int]]
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def desc_or_prompt(self):
@ -304,36 +304,36 @@ class App(Base):
class AppModelConfig(Base):
__tablename__ = "app_model_configs"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
provider = mapped_column(String(255), nullable=True)
model_id = mapped_column(String(255), nullable=True)
configs = mapped_column(sa.JSON, nullable=True)
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
opening_statement = mapped_column(sa.Text)
suggested_questions = mapped_column(sa.Text)
suggested_questions_after_answer = mapped_column(sa.Text)
speech_to_text = mapped_column(sa.Text)
text_to_speech = mapped_column(sa.Text)
more_like_this = mapped_column(sa.Text)
model = mapped_column(sa.Text)
user_input_form = mapped_column(sa.Text)
dataset_query_variable = mapped_column(String(255))
pre_prompt = mapped_column(sa.Text)
agent_mode = mapped_column(sa.Text)
sensitive_word_avoidance = mapped_column(sa.Text)
retriever_resource = mapped_column(sa.Text)
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying"))
chat_prompt_config = mapped_column(sa.Text)
completion_prompt_config = mapped_column(sa.Text)
dataset_configs = mapped_column(sa.Text)
external_data_tools = mapped_column(sa.Text)
file_upload = mapped_column(sa.Text)
@property
def app(self):
@ -555,24 +555,24 @@ class AppModelConfig(Base):
class RecommendedApp(Base):
__tablename__ = "recommended_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
db.Index("recommended_app_app_id_idx", "app_id"),
db.Index("recommended_app_is_listed_idx", "is_listed", "language"),
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
sa.Index("recommended_app_app_id_idx", "app_id"),
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
description = mapped_column(sa.JSON, nullable=False)
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
category: Mapped[str] = mapped_column(String(255), nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@ -583,20 +583,20 @@ class RecommendedApp(Base):
class InstalledApp(Base):
__tablename__ = "installed_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
db.Index("installed_app_tenant_id_idx", "tenant_id"),
db.Index("installed_app_app_id_idx", "app_id"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
sa.PrimaryKeyConstraint("id", name="installed_app_pkey"),
sa.Index("installed_app_tenant_id_idx", "tenant_id"),
sa.Index("installed_app_app_id_idx", "app_id"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@ -612,47 +612,47 @@ class InstalledApp(Base):
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
app_model_config_id = mapped_column(StringUUID, nullable=True)
model_provider = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text)
model_id = mapped_column(String(255), nullable=True)
mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
status: Mapped[str] = mapped_column(String(255), nullable=False)
# The `invoke_from` records how the conversation is created.
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
invoke_from = mapped_column(String(255), nullable=True)
# ref: ConversationSource.
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(sa.DateTime)
read_account_id = mapped_column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def inputs(self):
@ -894,36 +894,36 @@ class Message(Base):
Index("message_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
model_provider = mapped_column(db.String(255), nullable=True)
model_id = mapped_column(db.String(255), nullable=True)
override_model_configs = mapped_column(db.Text)
conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = mapped_column(db.Text, nullable=False)
message = mapped_column(db.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
model_provider = mapped_column(String(255), nullable=True)
model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
parent_message_id = mapped_column(StringUUID, nullable=True)
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
total_price = mapped_column(db.Numeric(10, 7))
currency = mapped_column(db.String(255), nullable=False)
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
error = mapped_column(db.Text)
message_metadata = mapped_column(db.Text)
invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
from_source = mapped_column(db.String(255), nullable=False)
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
error = mapped_column(sa.Text)
message_metadata = mapped_column(sa.Text)
invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
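Message mixes explicit `nullable=` flags with `Optional[...]` annotations; in SQLAlchemy 2.0 the two interact, since an omitted `nullable=` is inferred from the annotation. A sketch of both spellings (names illustrative):

from typing import Optional

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class MessageSketch(Base):
    __tablename__ = "message_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Explicit flag, as in the hunk above; redundant but harmless.
    invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    # Flag omitted: NOT NULL is inferred from the non-Optional annotation.
    from_source: Mapped[str] = mapped_column(String(255))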
@property
@@ -1230,23 +1230,23 @@ class Message(Base):
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
db.Index("message_feedback_app_idx", "app_id"),
db.Index("message_feedback_message_idx", "message_id", "from_source"),
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
sa.Index("message_feedback_app_idx", "app_id"),
sa.Index("message_feedback_message_idx", "message_id", "from_source"),
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
conversation_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
rating = mapped_column(db.String(255), nullable=False)
content = mapped_column(db.Text)
from_source = mapped_column(db.String(255), nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
content = mapped_column(sa.Text)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def from_account(self):
@@ -1272,9 +1272,9 @@ class MessageFeedback(Base):
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
db.Index("message_file_message_idx", "message_id"),
db.Index("message_file_created_by_idx", "created_by"),
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
sa.Index("message_file_message_idx", "message_id"),
sa.Index("message_file_created_by_idx", "created_by"),
)
def __init__(
@@ -1298,37 +1298,37 @@ class MessageFile(Base):
self.created_by_role = created_by_role.value
self.created_by = created_by
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
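The truncated `__init__` above stores `created_by_role.value`, i.e. the enum is flattened to its string for the String(255) column. A hedged sketch of that pattern; this CreatedByRole is a stand-in, not the project's enum:

import enum

class CreatedByRole(str, enum.Enum):
    ACCOUNT = "account"
    END_USER = "end_user"

role = CreatedByRole.ACCOUNT
assert role.value == "account"  # the string persisted in created_by_role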
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
db.Index("message_annotation_app_idx", "app_id"),
db.Index("message_annotation_conversation_idx", "conversation_id"),
db.Index("message_annotation_message_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
sa.Index("message_annotation_app_idx", "app_id"),
sa.Index("message_annotation_conversation_idx", "conversation_id"),
sa.Index("message_annotation_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id"))
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
question = db.Column(db.Text, nullable=True)
content = mapped_column(db.Text, nullable=False)
hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
question = db.Column(sa.Text, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def account(self):
@@ -1344,24 +1344,24 @@ class MessageAnnotation(Base):
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
db.Index("app_annotation_hit_histories_app_idx", "app_id"),
db.Index("app_annotation_hit_histories_account_idx", "account_id"),
db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
db.Index("app_annotation_hit_histories_message_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
sa.Index("app_annotation_hit_histories_app_idx", "app_id"),
sa.Index("app_annotation_hit_histories_account_idx", "account_id"),
sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source = mapped_column(db.Text, nullable=False)
question = mapped_column(db.Text, nullable=False)
source = mapped_column(sa.Text, nullable=False)
question = mapped_column(sa.Text, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=db.text("0"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id = mapped_column(StringUUID, nullable=False)
annotation_question = mapped_column(db.Text, nullable=False)
annotation_content = mapped_column(db.Text, nullable=False)
annotation_question = mapped_column(sa.Text, nullable=False)
annotation_content = mapped_column(sa.Text, nullable=False)
@property
def account(self):
@@ -1382,18 +1382,18 @@ class AppAnnotationHitHistory(Base):
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
db.Index("app_annotation_settings_app_idx", "app_id"),
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
sa.Index("app_annotation_settings_app_idx", "app_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0"))
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id = mapped_column(StringUUID, nullable=False)
created_user_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = mapped_column(StringUUID, nullable=False)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def collection_binding_detail(self):
@@ -1410,58 +1410,58 @@ class AppAnnotationSetting(Base):
class OperationLog(Base):
__tablename__ = "operation_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
sa.PrimaryKeyConstraint("id", name="operation_log_pkey"),
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
action = mapped_column(db.String(255), nullable=False)
content = mapped_column(db.JSON)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip = mapped_column(db.String(255), nullable=False)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
action: Mapped[str] = mapped_column(String(255), nullable=False)
content = mapped_column(sa.JSON)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
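The `uuid_generate_v4()` server default used throughout these models is a PostgreSQL function provided by the uuid-ossp extension, so it must be enabled before the tables are created. A sketch, with a placeholder connection string:

import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://user:pass@localhost/example")  # placeholder DSN
with engine.begin() as conn:
    conn.execute(sa.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'))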
class EndUser(Base, UserMixin):
__tablename__ = "end_users"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
db.Index("end_user_session_id_idx", "session_id", "type"),
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
sa.PrimaryKeyConstraint("id", name="end_user_pkey"),
sa.Index("end_user_session_id_idx", "session_id", "type"),
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(db.String(255), nullable=False)
external_user_id = mapped_column(db.String(255), nullable=True)
name = mapped_column(db.String(255))
is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
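Note the bare `mapped_column()` on `session_id` above: with no explicit type, SQLAlchemy 2.0 derives the column type from the `Mapped[str]` annotation. The default mapping can be customized on the declarative base; the override below is illustrative, not necessarily what this codebase configures:

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    # Map bare `str` annotations to VARCHAR(255) instead of unbounded String.
    type_annotation_map = {str: String(255)}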
class AppMCPServer(Base):
__tablename__ = "app_mcp_servers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
name = mapped_column(db.String(255), nullable=False)
description = mapped_column(db.String(255), nullable=False)
server_code = mapped_column(db.String(255), nullable=False)
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
parameters = mapped_column(db.Text, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
server_code: Mapped[str] = mapped_column(String(255), nullable=False)
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
parameters = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
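The `'normal'::character varying` default above is a PostgreSQL cast expression passed verbatim through `sa.text()`. A dialect-neutral alternative (shown only for contrast, not what the commit does) is to hand `server_default` a plain string and let SQLAlchemy quote it:

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class StatusSketch(Base):
    __tablename__ = "status_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(255), server_default="normal")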
@staticmethod
def generate_server_code(n):
@@ -1480,35 +1480,35 @@ class AppMCPServer(Base):
class Site(Base):
__tablename__ = "sites"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="site_pkey"),
db.Index("site_app_id_idx", "app_id"),
db.Index("site_code_idx", "code", "status"),
sa.PrimaryKeyConstraint("id", name="site_pkey"),
sa.Index("site_app_id_idx", "app_id"),
sa.Index("site_code_idx", "code", "status"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
title = mapped_column(db.String(255), nullable=False)
icon_type = mapped_column(db.String(255), nullable=True)
icon = mapped_column(db.String(255))
icon_background = mapped_column(db.String(255))
description = mapped_column(db.Text)
default_language = mapped_column(db.String(255), nullable=False)
chat_color_theme = mapped_column(db.String(255))
chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
copyright = mapped_column(db.String(255))
privacy_policy = mapped_column(db.String(255))
show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
title: Mapped[str] = mapped_column(String(255), nullable=False)
icon_type = mapped_column(String(255), nullable=True)
icon = mapped_column(String(255))
icon_background = mapped_column(String(255))
description = mapped_column(sa.Text)
default_language: Mapped[str] = mapped_column(String(255), nullable=False)
chat_color_theme = mapped_column(String(255))
chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
copyright = mapped_column(String(255))
privacy_policy = mapped_column(String(255))
show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
customize_domain = mapped_column(db.String(255))
customize_token_strategy = mapped_column(db.String(255), nullable=False)
prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
customize_domain = mapped_column(String(255))
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
code = mapped_column(db.String(255))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
code = mapped_column(String(255))
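`_custom_disclaimer` above shows the private-column idiom: the attribute is underscore-prefixed, the first positional argument to `mapped_column()` keeps the real column name, and a property supplies the public API. A sketch of the shape; the setter's validation is illustrative, not lifted from the commit:

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class SiteSketch(Base):
    __tablename__ = "site_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")

    @property
    def custom_disclaimer(self) -> str:
        return self._custom_disclaimer

    @custom_disclaimer.setter
    def custom_disclaimer(self, value: str) -> None:
        if len(value) > 512:  # hypothetical limit
            raise ValueError("Custom disclaimer is too long.")
        self._custom_disclaimer = value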
@property
def custom_disclaimer(self):
@@ -1537,19 +1537,19 @@ class Site(Base):
class ApiToken(Base):
__tablename__ = "api_tokens"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
db.Index("api_token_app_id_type_idx", "app_id", "type"),
db.Index("api_token_token_idx", "token", "type"),
db.Index("api_token_tenant_idx", "tenant_id", "type"),
sa.PrimaryKeyConstraint("id", name="api_token_pkey"),
sa.Index("api_token_app_id_type_idx", "app_id", "type"),
sa.Index("api_token_token_idx", "token", "type"),
sa.Index("api_token_tenant_idx", "tenant_id", "type"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(db.String(16), nullable=False)
token = mapped_column(db.String(255), nullable=False)
last_used_at = mapped_column(db.DateTime, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
type = mapped_column(String(16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
@@ -1563,27 +1563,27 @@ class ApiToken(Base):
class UploadFile(Base):
__tablename__ = "upload_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
db.Index("upload_file_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
sa.Index("upload_file_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
key: Mapped[str] = mapped_column(db.String(255), nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
size: Mapped[int] = mapped_column(db.Integer, nullable=False)
extension: Mapped[str] = mapped_column(db.String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
extension: Mapped[str] = mapped_column(String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
created_by_role: Mapped[str] = mapped_column(
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
String(255), nullable=False, server_default=sa.text("'account'::character varying")
)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
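UploadFile spells nullable columns as `Mapped[str | None]` while earlier models use `Mapped[Optional[str]]`; the two are the same union type, and SQLAlchemy infers nullability identically from either. An illustrative check, plus annotation-only columns (names are stand-ins):

from typing import Optional

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

assert (str | None) == Optional[str]  # both denote the same union type

class Base(DeclarativeBase):
    pass

class FileSketch(Base):
    __tablename__ = "file_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    used_by: Mapped[str | None]        # nullable, PEP 604 spelling
    checksum: Mapped[Optional[str]]    # nullable, typing.Optional spelling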
def __init__(
@@ -1625,71 +1625,71 @@ class UploadFile(Base):
class ApiRequest(Base):
__tablename__ = "api_requests"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
db.Index("api_request_token_idx", "tenant_id", "api_token_id"),
sa.PrimaryKeyConstraint("id", name="api_request_pkey"),
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
api_token_id = mapped_column(StringUUID, nullable=False)
path = mapped_column(db.String(255), nullable=False)
request = mapped_column(db.Text, nullable=True)
response = mapped_column(db.Text, nullable=True)
ip = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
path: Mapped[str] = mapped_column(String(255), nullable=False)
request = mapped_column(sa.Text, nullable=True)
response = mapped_column(sa.Text, nullable=True)
ip: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageChain(Base):
__tablename__ = "message_chains"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
db.Index("message_chain_message_id_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="message_chain_pkey"),
sa.Index("message_chain_message_id_idx", "message_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
type = mapped_column(db.String(255), nullable=False)
input = mapped_column(db.Text, nullable=True)
output = mapped_column(db.Text, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(sa.Text, nullable=True)
output = mapped_column(sa.Text, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(Base):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
db.Index("message_agent_thought_message_id_idx", "message_id"),
db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
sa.Index("message_agent_thought_message_id_idx", "message_id"),
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
position = mapped_column(db.Integer, nullable=False)
thought = mapped_column(db.Text, nullable=True)
tool = mapped_column(db.Text, nullable=True)
tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = mapped_column(db.Text, nullable=True)
observation = mapped_column(db.Text, nullable=True)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(sa.Text, nullable=True)
tool = mapped_column(sa.Text, nullable=True)
tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
tool_input = mapped_column(sa.Text, nullable=True)
observation = mapped_column(sa.Text, nullable=True)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(db.Text, nullable=True)
message = mapped_column(db.Text, nullable=True)
message_token = mapped_column(db.Integer, nullable=True)
message_unit_price = mapped_column(db.Numeric, nullable=True)
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
message_files = mapped_column(db.Text, nullable=True)
answer = db.Column(db.Text, nullable=True)
answer_token = mapped_column(db.Integer, nullable=True)
answer_unit_price = mapped_column(db.Numeric, nullable=True)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
tokens = mapped_column(db.Integer, nullable=True)
total_price = mapped_column(db.Numeric, nullable=True)
currency = mapped_column(db.String, nullable=True)
latency = mapped_column(db.Float, nullable=True)
created_by_role = mapped_column(db.String, nullable=False)
tool_process_data = mapped_column(sa.Text, nullable=True)
message = mapped_column(sa.Text, nullable=True)
message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(sa.Text, nullable=True)
answer = db.Column(sa.Text, nullable=True)
answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String, nullable=True)
latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
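The `Numeric(10, 7)` price columns above preserve exact decimal precision (10 digits, 7 after the point) and round-trip as `decimal.Decimal`, never float. A quick illustration using the hunk's 0.001 default:

from decimal import Decimal

price_unit = Decimal("0.001")      # matches the server_default above
total = price_unit * Decimal(42)   # Decimal arithmetic stays exact
print(total)                       # -> 0.042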
@property
def files(self) -> list:
@@ -1771,80 +1771,80 @@ class MessageAgentThought(Base):
class DatasetRetrieverResource(Base):
__tablename__ = "dataset_retriever_resources"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
db.Index("dataset_retriever_resource_message_id_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
position = mapped_column(db.Integer, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
dataset_name = mapped_column(db.Text, nullable=False)
dataset_name = mapped_column(sa.Text, nullable=False)
document_id = mapped_column(StringUUID, nullable=True)
document_name = mapped_column(db.Text, nullable=False)
data_source_type = mapped_column(db.Text, nullable=True)
document_name = mapped_column(sa.Text, nullable=False)
data_source_type = mapped_column(sa.Text, nullable=True)
segment_id = mapped_column(StringUUID, nullable=True)
score = mapped_column(db.Float, nullable=True)
content = mapped_column(db.Text, nullable=False)
hit_count = mapped_column(db.Integer, nullable=True)
word_count = mapped_column(db.Integer, nullable=True)
segment_position = mapped_column(db.Integer, nullable=True)
index_node_hash = mapped_column(db.Text, nullable=True)
retriever_from = mapped_column(db.Text, nullable=False)
score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(Base):
__tablename__ = "tags"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_pkey"),
db.Index("tag_type_idx", "type"),
db.Index("tag_name_idx", "name"),
sa.PrimaryKeyConstraint("id", name="tag_pkey"),
sa.Index("tag_type_idx", "type"),
sa.Index("tag_name_idx", "name"),
)
TAG_TYPE_LIST = ["knowledge", "app"]
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(db.String(16), nullable=False)
name = mapped_column(db.String(255), nullable=False)
type = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class TagBinding(Base):
__tablename__ = "tag_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
db.Index("tag_bind_target_id_idx", "target_id"),
db.Index("tag_bind_tag_id_idx", "tag_id"),
sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
sa.Index("tag_bind_target_id_idx", "target_id"),
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
tag_id = mapped_column(StringUUID, nullable=True)
target_id = mapped_column(StringUUID, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class TraceAppConfig(Base):
__tablename__ = "trace_app_config"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
db.Index("trace_app_config_app_id_idx", "app_id"),
sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
sa.Index("trace_app_config_app_id_idx", "app_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tracing_provider = mapped_column(db.String(255), nullable=True)
tracing_config = mapped_column(db.JSON, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
tracing_provider = mapped_column(String(255), nullable=True)
tracing_config = mapped_column(sa.JSON, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
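TraceAppConfig.updated_at pairs `server_default` (filled by the database on INSERT) with `onupdate` (re-evaluated by SQLAlchemy on every ORM UPDATE). A minimal sketch of the combination:

from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ConfigSketch(Base):
    __tablename__ = "config_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=func.current_timestamp(),  # applied on INSERT
        onupdate=func.current_timestamp(),        # re-applied on UPDATE
    )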
@property
def tracing_config_dict(self):

View File

@@ -2,11 +2,11 @@ from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import func, text
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
@@ -47,31 +47,31 @@ class Provider(Base):
__tablename__ = "providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_pkey"),
db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
db.UniqueConstraint(
sa.PrimaryKeyConstraint("id", name="provider_pkey"),
sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
sa.UniqueConstraint(
"tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota"
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'custom'::character varying")
String(40), nullable=False, server_default=text("'custom'::character varying")
)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
quota_type: Mapped[Optional[str]] = mapped_column(
db.String(40), nullable=True, server_default=text("''::character varying")
String(40), nullable=True, server_default=text("''::character varying")
)
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
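Provider.quota_used keeps a Python-side `default=0`, applied by the ORM at flush time, whereas the timestamps rely on `server_default`, applied by the database itself. A sketch of the distinction (names illustrative):

from datetime import datetime
from typing import Optional

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class QuotaSketch(Base):
    __tablename__ = "quota_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    # The ORM fills this in before INSERT; raw SQL inserts would get NULL.
    quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
    # The database fills this in, so raw SQL inserts are covered too.
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())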
def __repr__(self):
return (
@@ -104,80 +104,80 @@ class ProviderModel(Base):
__tablename__ = "provider_models"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_model_pkey"),
db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
db.UniqueConstraint(
sa.PrimaryKeyConstraint("id", name="provider_model_pkey"),
sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
sa.UniqueConstraint(
"tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name"
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base):
__tablename__ = "tenant_default_models"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
__tablename__ = "tenant_preferred_model_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
__tablename__ = "provider_orders"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
sa.PrimaryKeyConstraint("id", name="provider_order_pkey"),
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(db.String(40))
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(String(40))
total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
String(40), nullable=False, server_default=text("'wait_pay'::character varying")
)
paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
@@ -187,19 +187,19 @@ class ProviderModelSetting(Base):
__tablename__ = "provider_model_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
@@ -209,17 +209,17 @@ class LoadBalancingModelConfig(Base):
__tablename__ = "load_balancing_model_configs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@@ -1,49 +1,51 @@
import json
from datetime import datetime
from typing import Optional
from sqlalchemy import func
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from models.base import Base
from .engine import db
from .types import StringUUID
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
db.Index("source_binding_tenant_id_idx", "tenant_id"),
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
access_token = mapped_column(db.String(255), nullable=False)
provider = mapped_column(db.String(255), nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info = mapped_column(JSONB, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
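`source_info` above is a JSONB column backed by a GIN index (`postgresql_using="gin"` in `__table_args__`), which is what makes containment lookups cheap. A sketch with stand-in names:

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class BindingSketch(Base):
    __tablename__ = "binding_sketch"
    __table_args__ = (sa.Index("binding_sketch_info_idx", "source_info", postgresql_using="gin"),)
    id: Mapped[int] = mapped_column(primary_key=True)
    source_info: Mapped[dict] = mapped_column(JSONB, nullable=False)

# A containment query the GIN index can serve (renders source_info @> :param):
stmt = sa.select(BindingSketch).where(BindingSketch.source_info.contains({"workspace_id": "ws-123"}))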
class DataSourceApiKeyAuthBinding(Base):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
db.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
category = mapped_column(db.String(255), nullable=False)
provider = mapped_column(db.String(255), nullable=False)
credentials = mapped_column(db.Text, nullable=True) # JSON
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
credentials = mapped_column(sa.Text, nullable=True) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
def to_dict(self):
return {

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from celery import states # type: ignore
from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
from libs.datetime_utils import naive_utc_now
@@ -15,23 +17,23 @@ class CeleryTask(Base):
__tablename__ = "celery_taskmeta"
id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
task_id = mapped_column(db.String(155), unique=True)
status = mapped_column(db.String(50), default=states.PENDING)
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
task_id = mapped_column(String(155), unique=True)
status = mapped_column(String(50), default=states.PENDING)
result = mapped_column(db.PickleType, nullable=True)
date_done = mapped_column(
db.DateTime,
DateTime,
default=lambda: naive_utc_now(),
onupdate=lambda: naive_utc_now(),
nullable=True,
)
traceback = mapped_column(db.Text, nullable=True)
name = mapped_column(db.String(155), nullable=True)
args = mapped_column(db.LargeBinary, nullable=True)
kwargs = mapped_column(db.LargeBinary, nullable=True)
worker = mapped_column(db.String(155), nullable=True)
retries = mapped_column(db.Integer, nullable=True)
queue = mapped_column(db.String(155), nullable=True)
traceback = mapped_column(sa.Text, nullable=True)
name = mapped_column(String(155), nullable=True)
args = mapped_column(sa.LargeBinary, nullable=True)
kwargs = mapped_column(sa.LargeBinary, nullable=True)
worker = mapped_column(String(155), nullable=True)
retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
queue = mapped_column(String(155), nullable=True)
class CeleryTaskSet(Base):
@ -40,8 +42,8 @@ class CeleryTaskSet(Base):
__tablename__ = "celery_tasksetmeta"
id: Mapped[int] = mapped_column(
db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
)
taskset_id = mapped_column(db.String(155), unique=True)
taskset_id = mapped_column(String(155), unique=True)
result = mapped_column(db.PickleType, nullable=True)
date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)

View File

@ -5,7 +5,7 @@ from urllib.parse import urlparse
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, func
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.file import helpers as file_helpers
@ -25,33 +25,33 @@ from .types import StringUUID
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
plugin_id = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property
def oauth_params(self) -> dict:
@ -65,35 +65,35 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider
provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
provider: Mapped[str] = mapped_column(String(256), nullable=False)
# credential of the tool provider
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
)
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
@property
def credentials(self) -> dict:
@ -107,35 +107,35 @@ class ApiToolProvider(Base):
__tablename__ = "tool_api_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the api provider
name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
# icon
icon = mapped_column(db.String(255), nullable=False)
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema = mapped_column(db.Text, nullable=False)
schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False)
schema = mapped_column(sa.Text, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
# who created this tool
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
# description of the provider
description = mapped_column(db.Text, nullable=False)
description = mapped_column(sa.Text, nullable=False)
# json format tools
tools_str = mapped_column(db.Text, nullable=False)
tools_str = mapped_column(sa.Text, nullable=False)
# json format credentials
credentials_str = mapped_column(db.Text, nullable=False)
credentials_str = mapped_column(sa.Text, nullable=False)
# privacy policy
privacy_policy = mapped_column(db.String(255), nullable=True)
privacy_policy = mapped_column(String(255), nullable=True)
# custom_disclaimer
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def schema_type(self) -> ApiProviderSchemaType:
@ -167,17 +167,17 @@ class ToolLabelBinding(Base):
__tablename__ = "tool_label_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# tool id
tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
# label name
label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
class WorkflowToolProvider(Base):
@ -187,38 +187,38 @@ class WorkflowToolProvider(Base):
__tablename__ = "tool_workflow_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the workflow provider
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
# icon
icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# app id of the workflow provider
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# version of the workflow provider
version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# description of the provider
description: Mapped[str] = mapped_column(db.Text, nullable=False)
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# parameter configuration
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
# privacy policy
privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="")
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
@property
@ -245,38 +245,38 @@ class MCPToolProvider(Base):
__tablename__ = "tool_mcp_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
server_identifier: Mapped[str] = mapped_column(String(64), nullable=False)
# encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
server_url: Mapped[str] = mapped_column(sa.Text, nullable=False)
# hash of server_url for uniqueness check
server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
# icon of the mcp provider
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
icon: Mapped[str] = mapped_column(String(255), nullable=True)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# encrypted credentials
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
# authed
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
# tools
tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
def load_user(self) -> Account | None:
@ -347,35 +347,35 @@ class ToolModelInvoke(Base):
"""
__tablename__ = "tool_model_invokes"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# who invoked this tool
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
# provider
provider = mapped_column(db.String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type = mapped_column(db.String(40), nullable=False)
tool_type = mapped_column(String(40), nullable=False)
# tool name
tool_name = mapped_column(db.String(128), nullable=False)
tool_name = mapped_column(String(128), nullable=False)
# invoke parameters
model_parameters = mapped_column(db.Text, nullable=False)
model_parameters = mapped_column(sa.Text, nullable=False)
# prompt messages
prompt_messages = mapped_column(db.Text, nullable=False)
prompt_messages = mapped_column(sa.Text, nullable=False)
# invoke response
model_response = mapped_column(db.Text, nullable=False)
model_response = mapped_column(sa.Text, nullable=False)
prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
total_price = mapped_column(db.Numeric(10, 7))
currency = mapped_column(db.String(255), nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@deprecated
@ -386,13 +386,13 @@ class ToolConversationVariables(Base):
__tablename__ = "tool_conversation_variables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
# add index for user_id and conversation_id
db.Index("user_id_idx", "user_id"),
db.Index("conversation_id_idx", "conversation_id"),
sa.Index("user_id_idx", "user_id"),
sa.Index("conversation_id_idx", "conversation_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# conversation user id
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
@ -400,10 +400,10 @@ class ToolConversationVariables(Base):
# conversation id
conversation_id = mapped_column(StringUUID, nullable=False)
# variables pool
variables_str = mapped_column(db.Text, nullable=False)
variables_str = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def variables(self) -> Any:
@ -417,11 +417,11 @@ class ToolFile(Base):
__tablename__ = "tool_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
db.Index("tool_file_conversation_id_idx", "conversation_id"),
sa.PrimaryKeyConstraint("id", name="tool_file_pkey"),
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
@ -429,11 +429,11 @@ class ToolFile(Base):
# conversation id
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# file key
file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
file_key: Mapped[str] = mapped_column(String(255), nullable=False)
# mime type
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
# original url
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
original_url: Mapped[str] = mapped_column(String(2048), nullable=True)
# name
name: Mapped[str] = mapped_column(default="")
# size
@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base):
__tablename__ = "tool_published_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# id of the app
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who published this tool
description = mapped_column(db.Text, nullable=False)
description = mapped_column(sa.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = mapped_column(db.Text, nullable=False)
llm_description = mapped_column(sa.Text, nullable=False)
# query description; the query will be seen as a parameter of the tool,
# and this field describes that parameter to the LLM
query_description = mapped_column(db.Text, nullable=False)
query_description = mapped_column(sa.Text, nullable=False)
# query name, the name of the query parameter
query_name = mapped_column(db.String(40), nullable=False)
query_name = mapped_column(String(40), nullable=False)
# name of the tool provider
tool_name = mapped_column(db.String(40), nullable=False)
tool_name = mapped_column(String(40), nullable=False)
# author
author = mapped_column(db.String(40), nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
author = mapped_column(String(40), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
@property
def description_i18n(self) -> I18nObject:

View File

@ -1,4 +1,7 @@
from sqlalchemy import func
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from models.base import Base
@ -11,18 +14,18 @@ from .types import StringUUID
class SavedMessage(Base):
__tablename__ = "saved_messages"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
created_by_role = mapped_column(
db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@property
def message(self):
@ -32,15 +35,15 @@ class SavedMessage(Base):
class PinnedConversation(Base):
__tablename__ = "pinned_conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = mapped_column(
db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -6,8 +6,9 @@ from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
import sqlalchemy as sa
from flask_login import current_user
from sqlalchemy import orm
from sqlalchemy import DateTime, orm
from core.file.constants import maybe_file_object
from core.file.models import File
@ -24,8 +25,7 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func
from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
@ -118,33 +118,33 @@ class Workflow(Base):
__tablename__ = "workflows"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_pkey"),
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
sa.PrimaryKeyConstraint("id", name="workflow_pkey"),
sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
version: Mapped[str] = mapped_column(String(255), nullable=False)
marked_name: Mapped[str] = mapped_column(default="", server_default="")
marked_comment: Mapped[str] = mapped_column(default="", server_default="")
graph: Mapped[str] = mapped_column(sa.Text)
_features: Mapped[str] = mapped_column("features", sa.TEXT)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
DateTime,
nullable=False,
default=naive_utc_now(),
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", db.Text, nullable=False, server_default="{}"
"environment_variables", sa.Text, nullable=False, server_default="{}"
)
_conversation_variables: Mapped[str] = mapped_column(
"conversation_variables", db.Text, nullable=False, server_default="{}"
"conversation_variables", sa.Text, nullable=False, server_default="{}"
)
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
@ -521,31 +521,31 @@ class WorkflowRun(Base):
__tablename__ = "workflow_runs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
type: Mapped[str] = mapped_column(db.String(255))
triggered_from: Mapped[str] = mapped_column(db.String(255))
version: Mapped[str] = mapped_column(db.String(255))
graph: Mapped[Optional[str]] = mapped_column(db.Text)
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
type: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[Optional[str]] = mapped_column(sa.Text)
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
error: Mapped[Optional[str]] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
@property
def created_by_account(self):
@ -735,29 +735,29 @@ class WorkflowNodeExecutionModel(Base):
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from: Mapped[str] = mapped_column(db.String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
index: Mapped[int] = mapped_column(db.Integer)
predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
node_id: Mapped[str] = mapped_column(db.String(255))
node_type: Mapped[str] = mapped_column(db.String(255))
title: Mapped[str] = mapped_column(db.String(255))
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
process_data: Mapped[Optional[str]] = mapped_column(db.Text)
outputs: Mapped[Optional[str]] = mapped_column(db.Text)
status: Mapped[str] = mapped_column(db.String(255))
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(db.String(255))
index: Mapped[int] = mapped_column(sa.Integer)
predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
node_id: Mapped[str] = mapped_column(String(255))
node_type: Mapped[str] = mapped_column(String(255))
title: Mapped[str] = mapped_column(String(255))
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
process_data: Mapped[Optional[str]] = mapped_column(sa.Text)
outputs: Mapped[Optional[str]] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255))
error: Mapped[Optional[str]] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255))
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
@property
def created_by_account(self):
@ -865,19 +865,19 @@ class WorkflowAppLog(Base):
__tablename__ = "workflow_app_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@property
def workflow_run(self):
@ -902,12 +902,12 @@ class ConversationVariable(Base):
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data: Mapped[str] = mapped_column(db.Text, nullable=False)
data: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
@ -964,17 +964,17 @@ class WorkflowDraftVariable(Base):
__allow_unmapped__ = True
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
created_at: Mapped[datetime] = mapped_column(
db.DateTime,
DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
DateTime,
nullable=False,
default=_naive_utc_datetime,
server_default=func.current_timestamp(),
@ -989,7 +989,7 @@ class WorkflowDraftVariable(Base):
#
# If it's not edited after creation, its value is `None`.
last_edited_at: Mapped[datetime | None] = mapped_column(
db.DateTime,
DateTime,
nullable=True,
default=None,
)

View File

@ -114,6 +114,7 @@ dev = [
"pytest-cov~=4.1.0",
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"testcontainers~=4.10.0",
"types-aiofiles~=24.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=5.5.0",

View File

@ -12,10 +12,10 @@ from libs.email_i18n import EmailType, get_email_i18n_service
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
celery_redis = Redis(
host=redis_config["hostname"],
port=redis_config["port"],
password=redis_config["password"],
db=int(redis_config["virtual_host"]) if redis_config["virtual_host"] else 1,
host=redis_config.get("hostname") or "localhost",
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
)
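
The rewritten Redis client construction swaps dict indexing for .get() with fallbacks, so a minimal broker URL no longer hands None to the client. A sketch of the failure mode, assuming kombu-style parse_url semantics (the helper's import sits outside this hunk):

from kombu.utils.url import parse_url  # assumption: the parse_url this module uses

cfg = parse_url("redis://:secret@redis-host")  # URL without an explicit port
# cfg["port"] is None here, so the old indexing code passed None to Redis(...)
port = cfg.get("port") or 6379  # the new fallback keeps the client constructible
assert port == 6379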

View File

@ -12,6 +12,7 @@ import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from packaging.version import parse as parse_version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -269,7 +270,7 @@ class AppDslService:
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
elif imported_version <= "0.1.5":
elif parse_version(imported_version) <= parse_version("0.1.5"):
if "workflow" in data:
graph = data.get("workflow", {}).get("graph", {})
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
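
The version-guard change above fixes a real ordering bug: comparing version strings with <= is lexicographic, which breaks as soon as any component reaches two digits. A tiny illustration (values chosen for the example):

from packaging.version import parse as parse_version

assert "0.2.0" > "0.15.0"                                # string comparison: wrong order
assert parse_version("0.2.0") < parse_version("0.15.0")  # semantic comparison: correct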

View File

@ -1,5 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Optional, Union
from openai._exceptions import RateLimitError
@ -15,6 +16,7 @@ from libs.helper import RateLimiter
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.billing_service import BillingService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
@ -86,7 +88,8 @@ class AppGenerateService:
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, invoke_from)
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().generate(
@ -101,7 +104,8 @@ class AppGenerateService:
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, invoke_from)
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
@ -210,14 +214,27 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
"""
Get workflow
:param app_model: app model
:param invoke_from: invoke from
:param workflow_id: optional workflow id to specify a specific version
:return:
"""
workflow_service = WorkflowService()
# If workflow_id is specified, get the specific workflow version
if workflow_id:
try:
uuid.UUID(workflow_id)  # validate the format only; the original string is used below
except ValueError:
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'.")
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}")
return workflow
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = workflow_service.get_draft_workflow(app_model=app_model)
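
The new branch validates the id before any lookup: uuid.UUID(...) raises ValueError on malformed input, which is surfaced as the domain-specific WorkflowIdFormatError. A standalone sketch of that check (the sample ids are made up):

import uuid

def is_valid_workflow_id(candidate: str) -> bool:
    try:
        uuid.UUID(candidate)
        return True
    except ValueError:
        return False

assert not is_valid_workflow_id("not-a-uuid")
assert is_valid_workflow_id("7c1b19a4-8a3c-4f0e-9a6b-2f4d3c2b1a00")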

View File

@ -159,9 +159,9 @@ class BillingService:
):
limiter_key = f"{account_id}:{tenant_id}"
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
from controllers.console.error import CompilanceRateLimitError
from controllers.console.error import ComplianceRateLimitError
raise CompilanceRateLimitError()
raise ComplianceRateLimitError()
json = {
"doc_name": doc_name,

View File

@ -1,12 +1,15 @@
from collections.abc import Callable, Sequence
from typing import Optional, Union
from typing import Any, Optional, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from core.variables.types import SegmentType
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from extensions.ext_database import db
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import ConversationVariable
@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message
from services.errors.conversation import (
ConversationNotExistsError,
ConversationVariableNotExistsError,
ConversationVariableTypeMismatchError,
LastConversationNotExistsError,
)
from services.errors.message import MessageNotExistsError
@ -220,3 +224,82 @@ class ConversationService:
]
return InfiniteScrollPagination(variables, limit, has_more)
@classmethod
def update_conversation_variable(
cls,
app_model: App,
conversation_id: str,
variable_id: str,
user: Optional[Union[Account, EndUser]],
new_value: Any,
) -> dict:
"""
Update a conversation variable's value.
Args:
app_model: The app model
conversation_id: The conversation ID
variable_id: The variable ID to update
user: The user (Account or EndUser)
new_value: The new value for the variable
Returns:
Dictionary containing the updated variable information
Raises:
ConversationNotExistsError: If the conversation doesn't exist
ConversationVariableNotExistsError: If the variable doesn't exist
ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type
"""
# Verify conversation exists and user has access
conversation = cls.get_conversation(app_model, conversation_id, user)
# Get the existing conversation variable
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.where(ConversationVariable.conversation_id == conversation.id)
.where(ConversationVariable.id == variable_id)
)
with Session(db.engine) as session:
existing_variable = session.scalar(stmt)
if not existing_variable:
raise ConversationVariableNotExistsError()
# Convert existing variable to Variable object
current_variable = existing_variable.to_variable()
# Validate that the new value type matches the expected variable type
expected_type = SegmentType(current_variable.value_type)
if not expected_type.is_valid(new_value):
inferred_type = SegmentType.infer_segment_type(new_value)
raise ConversationVariableTypeMismatchError(
f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, "
f"but got {inferred_type.value if inferred_type else 'unknown'} type"
)
# Create updated variable with new value only, preserving everything else
updated_variable_dict = {
"id": current_variable.id,
"name": current_variable.name,
"description": current_variable.description,
"value_type": current_variable.value_type,
"value": new_value,
"selector": current_variable.selector,
}
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
# Use the conversation variable updater to persist the changes
updater = conversation_variable_updater_factory()
updater.update(conversation_id, updated_variable)
updater.flush()
# Return the updated variable data
return {
"created_at": existing_variable.created_at,
"updated_at": naive_utc_now(), # Update timestamp
**updated_variable.model_dump(),
}
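
For reference, a hedged usage sketch of the new service method; app, current_user, and the ids are assumed to come from the request context, and the returned mapping follows from the model_dump() merge at the end of the method:

result = ConversationService.update_conversation_variable(
    app_model=app,
    conversation_id=conversation_id,
    variable_id=variable_id,
    user=current_user,
    new_value="Berlin",  # must match the variable's declared value_type
)
print(result["updated_at"], result["name"])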

View File

@ -311,7 +311,7 @@ class DatasetService:
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
@staticmethod
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
@ -426,7 +426,7 @@ class DatasetService:
raise ValueError("External knowledge api id is required.")
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = datetime.datetime.utcnow()
dataset.updated_at = naive_utc_now()
db.session.add(dataset)
# Update external knowledge binding
@ -2498,6 +2498,7 @@ class SegmentService:
db.session.add(segment_document)
# update document word count
assert document.word_count is not None
document.word_count += segment_document.word_count
db.session.add(document)
db.session.commit()
@ -2582,6 +2583,7 @@ class SegmentService:
else:
keywords_list.append(None)
# update document word count
assert document.word_count is not None
document.word_count += increment_word_count
db.session.add(document)
try:
@ -2643,6 +2645,7 @@ class SegmentService:
db.session.commit()
# update document word count
if word_count_change != 0:
assert document.word_count is not None
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
# update segment index task
@ -2718,6 +2721,7 @@ class SegmentService:
word_count_change = segment.word_count - word_count_change
# update document word count
if word_count_change != 0:
assert document.word_count is not None
document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document)
db.session.add(segment)
@ -2781,6 +2785,7 @@ class SegmentService:
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
db.session.delete(segment)
# update document word count
assert document.word_count is not None
document.word_count -= segment.word_count
db.session.add(document)
db.session.commit()
@ -2825,7 +2830,7 @@ class SegmentService:
)
if not segments:
return
real_deal_segmment_ids = []
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
@ -2835,10 +2840,10 @@ class SegmentService:
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segmment_ids.append(segment.id)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
@ -2852,7 +2857,7 @@ class SegmentService:
)
if not segments:
return
real_deal_segmment_ids = []
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
@ -2862,10 +2867,10 @@ class SegmentService:
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segmment_ids.append(segment.id)
real_deal_segment_ids.append(segment.id)
db.session.commit()
disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
else:
raise InvalidActionError()
@ -3123,7 +3128,7 @@ class SegmentService:
# check segment
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
if not segment:

View File

@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception):
class IsDraftWorkflowError(Exception):
pass
class WorkflowNotFoundError(Exception):
pass
class WorkflowIdFormatError(Exception):
pass

View File

@ -15,3 +15,7 @@ class ConversationCompletedError(Exception):
class ConversationVariableNotExistsError(BaseServiceError):
pass
class ConversationVariableTypeMismatchError(BaseServiceError):
pass

View File

@ -79,7 +79,10 @@ class MetadataService:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
value = doc_metadata.pop(old_name, None)
doc_metadata[name] = value
document.doc_metadata = doc_metadata
@ -109,7 +112,10 @@ class MetadataService:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(metadata.name, None)
document.doc_metadata = doc_metadata
db.session.add(document)
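
The `if not document.doc_metadata` guard recurs in the hunks above and below; an equivalent, more compact spelling (a style note, not what the commit ships) would be:

doc_metadata = copy.deepcopy(document.doc_metadata or {})
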
@ -137,7 +143,6 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = True
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
if documents:
@ -153,6 +158,7 @@ class MetadataService:
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
db.session.commit()
except Exception:
logging.exception("Enable built-in field failed")
@ -166,13 +172,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = False
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
@ -181,6 +189,7 @@ class MetadataService:
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
dataset.built_in_field_enabled = False
db.session.commit()
except Exception:
logging.exception("Disable built-in field failed")

View File

@ -2,6 +2,7 @@ import json
import logging
import click
import sqlalchemy as sa
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db
@ -38,7 +39,7 @@ class PluginDataMigration:
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
rs = conn.execute(sa.text(sql))
current_iter_count = 0
for i in rs:
@ -94,7 +95,7 @@ limit 1000"""
:provider_name
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
conn.execute(sa.text(sql), params)
click.echo(
click.style(
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
@ -148,7 +149,7 @@ limit 1000"""
params = {"last_id": last_id or ""}
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql), params)
rs = conn.execute(sa.text(sql), params)
current_iter_count = 0
batch_updates = []
@ -193,7 +194,7 @@ limit 1000"""
SET {provider_column_name} = :updated_value
WHERE id = :record_id
"""
conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
click.echo(
click.style(
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",

View File

@ -9,6 +9,7 @@ from typing import Any, Optional
from uuid import uuid4
import click
import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session
@ -197,7 +198,7 @@ class PluginMigration:
"""
with Session(db.engine) as session:
rs = session.execute(
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
)
result = []
for row in rs:

View File

@ -486,10 +486,10 @@ class BuiltinToolManageService:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
# only verified provider can use custom oauth client
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
# only verified provider can use official oauth client
is_verified = not isinstance(
provider_controller, PluginToolProviderController
) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if not is_verified:
return oauth_params

View File

@ -422,7 +422,7 @@ class WorkflowDraftVariableService:
description=conv_var.description,
)
draft_conv_vars.append(draft_var)
_batch_upsert_draft_varaible(
_batch_upsert_draft_variable(
self._session,
draft_conv_vars,
policy=_UpsertPolicy.IGNORE,
@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum):
OVERWRITE = "overwrite"
def _batch_upsert_draft_varaible(
def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
@ -721,7 +721,7 @@ class DraftVariableSaver:
draft_vars = self._build_variables_from_start_mapping(outputs)
else:
draft_vars = self._build_variables_from_mapping(outputs)
_batch_upsert_draft_varaible(self._session, draft_vars)
_batch_upsert_draft_variable(self._session, draft_vars)
@staticmethod
def _should_variable_be_editable(node_id: str, name: str) -> bool:

View File

@ -129,7 +129,10 @@ class WorkflowService:
if not workflow:
return None
if workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
raise IsDraftWorkflowError(
f"Cannot use draft workflow version. Workflow ID: {workflow_id}. "
f"Please use a published workflow version or leave workflow_id empty."
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
@ -442,9 +445,9 @@ class WorkflowService:
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
Run draft workflow node
Run free workflow node
"""
# run draft workflow node
# run free workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(

View File

@ -134,6 +134,7 @@ def batch_create_segment_to_index_task(
db.session.add(segment_document)
document_segments.append(segment_document)
# update document word count
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db

View File

@@ -3,6 +3,7 @@ import time
from collections.abc import Callable
import click
import sqlalchemy as sa
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
@@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(db.text(query_sql), params)
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
break

View File

@@ -2,6 +2,7 @@ import os
import uuid
import tablestore
from _pytest.python_api import approx
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig,
@@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import (
class TableStoreVectorTest(AbstractVectorTest):
def __init__(self):
def __init__(self, normalize_full_text_score: bool = False):
super().__init__()
self.vector = TableStoreVector(
collection_name=self.collection_name,
@@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest):
instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
normalize_full_text_bm25_score=normalize_full_text_score,
),
)
@@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest):
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
assert len(docs) == 1
assert docs[0].metadata["doc_id"] == self.example_doc_id
assert not hasattr(docs[0], "score")
if self.vector._config.normalize_full_text_bm25_score:
assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
else:
assert docs[0].metadata.get("score") is None
# returns no documents when normalize_full_text_score=True and score_threshold exceeds the normalized score
docs = self.vector.search_by_full_text(
get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
)
if self.vector._config.normalize_full_text_bm25_score:
assert len(docs) == 0
else:
assert len(docs) == 1
assert docs[0].metadata["doc_id"] == self.example_doc_id
assert docs[0].metadata.get("score") is None
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
assert len(docs) == 0
@@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest):
def test_tablestore_vector(setup_mock_redis):
TableStoreVectorTest().run_all_tests()
TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
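The updated assertions boil down to: in normalized mode the BM25 score lands in [0, 1] and score_threshold can filter matches out, while in raw mode no score is attached and the threshold has no effect. A schematic restatement, not the driver's actual code:

def apply_threshold(docs, score_threshold, normalized):
    # Raw BM25 mode attaches no score, so a threshold cannot apply.
    if not normalized:
        return docs
    # Normalized mode: drop documents whose [0, 1] score falls below the threshold.
    return [d for d in docs if d.metadata.get("score", 0.0) >= score_threshold]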

View File

@@ -0,0 +1,328 @@
"""
TestContainers-based integration test configuration for Dify API.
This module provides containerized test infrastructure using TestContainers library
to spin up real database and service instances for integration testing. This approach
ensures tests run against actual service implementations rather than mocks, providing
more reliable and realistic test scenarios.
"""
import logging
import os
from collections.abc import Generator
from typing import Optional
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.postgres import PostgresContainer
from testcontainers.redis import RedisContainer
from app_factory import create_app
from models import db
# Configure logging for test containers
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
This class provides a centralized way to manage multiple containers
needed for comprehensive integration testing, including databases,
caches, and the code-execution sandbox.
"""
def __init__(self):
"""Initialize container management with default configurations."""
self.postgres: Optional[PostgresContainer] = None
self.redis: Optional[RedisContainer] = None
self.dify_sandbox: Optional[DockerContainer] = None
self._containers_started = False
logger.info("DifyTestContainers initialized - ready to manage test containers")
def start_containers_with_env(self) -> None:
"""
Start all required containers for integration testing.
This method initializes and starts the PostgreSQL, Redis, and Dify Sandbox
containers with appropriate configurations for Dify testing. Containers
are started in dependency order to ensure proper initialization.
"""
if self._containers_started:
logger.info("Containers already started - skipping container startup")
return
logger.info("Starting test containers for Dify integration tests...")
# Start PostgreSQL container for main application database
# PostgreSQL is used for storing user data, workflows, and application state
logger.info("Initializing PostgreSQL container...")
self.postgres = PostgresContainer(
image="postgres:16-alpine",
)
self.postgres.start()
db_host = self.postgres.get_container_host_ip()
db_port = self.postgres.get_exposed_port(5432)
os.environ["DB_HOST"] = db_host
os.environ["DB_PORT"] = str(db_port)
os.environ["DB_USERNAME"] = self.postgres.username
os.environ["DB_PASSWORD"] = self.postgres.password
os.environ["DB_DATABASE"] = self.postgres.dbname
logger.info(
"PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s",
db_host,
db_port,
self.postgres.username,
self.postgres.dbname,
)
# Wait for PostgreSQL to be ready
logger.info("Waiting for PostgreSQL to be ready to accept connections...")
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)
# Set up storage environment variables
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
logger.info("Initializing Redis container...")
self.redis = RedisContainer(image="redis:latest", port=6379)
self.redis.start()
redis_host = self.redis.get_container_host_ip()
redis_port = self.redis.get_exposed_port(6379)
os.environ["REDIS_HOST"] = redis_host
os.environ["REDIS_PORT"] = str(redis_port)
logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port)
# Wait for Redis to be ready
logger.info("Waiting for Redis to be ready to accept connections...")
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
logger.info("Redis container is ready and accepting connections")
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",
}
self.dify_sandbox.start()
sandbox_host = self.dify_sandbox.get_container_host_ip()
sandbox_port = self.dify_sandbox.get_exposed_port(8194)
os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}"
os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key"
logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port)
# Wait for Dify Sandbox to be ready
logger.info("Waiting for Dify Sandbox to be ready to accept connections...")
wait_for_logs(self.dify_sandbox, "config init success", timeout=60)
logger.info("Dify Sandbox container is ready and accepting connections")
self._containers_started = True
logger.info("All test containers started successfully")
def stop_containers(self) -> None:
"""
Stop and clean up all test containers.
This method ensures proper cleanup of all containers to prevent
resource leaks and conflicts between test runs.
"""
if not self._containers_started:
logger.info("No containers to stop - containers were not started")
return
logger.info("Stopping and cleaning up test containers...")
containers = [self.redis, self.postgres, self.dify_sandbox]
for container in containers:
if container:
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")
# Global container manager instance
_container_manager = DifyTestContainers()
def _create_app_with_containers() -> Flask:
"""
Create Flask application configured to use test containers.
This function creates a Flask application instance that is configured
to connect to the test containers instead of the default development
or production databases.
Returns:
Flask: Configured Flask application for containerized testing
"""
logger.info("Creating Flask application with test container configuration...")
# Re-create the config after environment variables have been set
from configs import dify_config
# Force re-creation of config with new environment variables
dify_config.__dict__.clear()
dify_config.__init__()
# Create and configure the Flask application
logger.info("Initializing Flask application...")
app = create_app()
logger.info("Flask application created successfully")
# Initialize database schema
logger.info("Creating database schema...")
with app.app_context():
db.create_all()
logger.info("Database schema created successfully")
logger.info("Flask application configured and ready for testing")
return app
@pytest.fixture(scope="session")
def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]:
"""
Session-scoped fixture to manage test containers.
This fixture ensures containers are started once per test session
and properly cleaned up when all tests are complete. This approach
improves test performance by reusing containers across multiple tests.
Yields:
DifyTestContainers: Container manager instance
"""
logger.info("=== Starting test session container management ===")
_container_manager.start_containers_with_env()
logger.info("Test containers ready for session")
yield _container_manager
logger.info("=== Cleaning up test session containers ===")
_container_manager.stop_containers()
logger.info("Test session container cleanup completed")
@pytest.fixture(scope="session")
def flask_app_with_containers(set_up_containers_and_env) -> Flask:
"""
Session-scoped Flask application fixture using test containers.
This fixture provides a Flask application instance that is configured
to use the test containers for all database and service connections.
Args:
set_up_containers_and_env: Container manager fixture
Returns:
Flask: Configured Flask application
"""
logger.info("=== Creating session-scoped Flask application ===")
app = _create_app_with_containers()
logger.info("Session-scoped Flask application created successfully")
return app
@pytest.fixture
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
"""
Request context fixture for containerized Flask application.
This fixture provides a Flask request context for tests that need
to interact with the Flask application within a request scope.
Args:
flask_app_with_containers: Flask application fixture
Yields:
None: Request context is active during yield
"""
logger.debug("Creating Flask request context...")
with flask_app_with_containers.test_request_context():
logger.debug("Flask request context active")
yield
logger.debug("Flask request context closed")
@pytest.fixture
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
"""
Test client fixture for containerized Flask application.
This fixture provides a Flask test client that can be used to make
HTTP requests to the containerized application for integration testing.
Args:
flask_app_with_containers: Flask application fixture
Yields:
FlaskClient: Test client instance
"""
logger.debug("Creating Flask test client...")
with flask_app_with_containers.test_client() as client:
logger.debug("Flask test client ready")
yield client
logger.debug("Flask test client closed")
@pytest.fixture
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
"""
Database session fixture for containerized testing.
This fixture provides a SQLAlchemy database session that is connected
to the test PostgreSQL container, allowing tests to interact with
the database directly.
Args:
flask_app_with_containers: Flask application fixture
Yields:
Session: Database session instance
"""
logger.debug("Creating database session...")
with flask_app_with_containers.app_context():
session = db.session()
logger.debug("Database session created and ready")
try:
yield session
finally:
session.close()
logger.debug("Database session closed")

View File

@@ -0,0 +1,371 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
class TestStorageKeyLoader(unittest.TestCase):
"""
Integration tests for StorageKeyLoader class.
Tests the batched loading of storage keys from the database for files
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
"""
def setUp(self):
"""Set up test data before each test method."""
self.session = db.session()
self.tenant_id = str(uuid4())
self.user_id = str(uuid4())
self.conversation_id = str(uuid4())
# Create test data that will be cleaned up after each test
self.test_upload_files = []
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
def tearDown(self):
"""Clean up test data after each test method."""
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if storage_key is None:
storage_key = f"test_storage_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=storage_key,
name="test_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file.id = file_id
self.session.add(upload_file)
self.session.flush()
self.test_upload_files.append(upload_file)
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if file_key is None:
file_key = f"test_file_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id
# Set related_id for every transfer method; REMOTE_URL additionally gets a remote_url
file_related_id = None
remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)
def test_load_storage_keys_local_file(self):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_remote_url(self):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_tool_file(self):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == tool_file.file_key
def test_load_storage_keys_mixed_methods(self):
"""Test batch loading with mixed transfer methods."""
# Create test data for different transfer methods
upload_file1 = self._create_upload_file()
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
files = [file1, file2, file3]
# Load storage keys
self.loader.load_storage_keys(files)
# Verify all storage keys were loaded correctly
assert file1._storage_key == upload_file1.key
assert file2._storage_key == upload_file2.key
assert file3._storage_key == tool_file.file_key
def test_load_storage_keys_empty_list(self):
"""Test with empty file list."""
# Should not raise any exceptions
self.loader.load_storage_keys([])
def test_load_storage_keys_tenant_mismatch(self):
"""Test tenant_id validation."""
# Create file with different tenant_id
upload_file = self._create_upload_file()
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)
# Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_missing_file_id(self):
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None
# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert str(context.value) == "file id should not be None."
def test_load_storage_keys_nonexistent_upload_file_records(self):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_nonexistent_tool_file_records(self):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_batch_efficiency(self):
"""Test batched operations use efficient queries."""
# Create multiple files of different types
upload_files = [self._create_upload_file() for _ in range(3)]
tool_files = [self._create_tool_file() for _ in range(2)]
files = []
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
self.loader.load_storage_keys(files)
# Should make exactly 2 queries (one for upload_files, one for tool_files)
assert mock_scalars.call_count == 2
# Verify all storage keys were loaded correctly
for i, file in enumerate(files[:3]):
assert file._storage_key == upload_files[i].key
for i, file in enumerate(files[3:]):
assert file._storage_key == tool_files[i].file_key
def test_load_storage_keys_tenant_isolation(self):
"""Test that tenant isolation works correctly."""
# Create files for different tenants
other_tenant_id = str(uuid4())
# Create upload file for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
key="other_tenant_key",
name="other_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file_other.id = str(uuid4())
self.session.add(upload_file_other)
self.session.flush()
# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_other])
assert "invalid file, expected tenant_id" in str(context.value)
# Current tenant's file should still work
self.loader.load_storage_keys([file_current])
assert file_current._storage_key == upload_file_current.key
def test_load_storage_keys_mixed_tenant_batch(self):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_current, file_other])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_duplicate_file_ids(self):
"""Test handling of duplicate file IDs in the batch."""
# Create upload file
upload_file = self._create_upload_file()
# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])
# Both files should have the same storage key
assert file1._storage_key == upload_file.key
assert file2._storage_key == upload_file.key
def test_load_storage_keys_session_isolation(self):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create loader with different session (same underlying connection)
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])
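Taken together, the batch-efficiency and tenant-isolation tests pin the loader to a two-query shape roughly like the sketch below (simplified; the real loader's column names and error handling differ):

from sqlalchemy import select
from core.file import FileTransferMethod
from models import ToolFile, UploadFile

def load_storage_keys_sketch(session, tenant_id, files):
    # One id set per backing table, so each table is queried exactly once.
    upload_ids = {f.related_id for f in files
                  if f.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL)}
    tool_ids = {f.related_id for f in files if f.transfer_method == FileTransferMethod.TOOL_FILE}
    key_by_id = {}
    for row in session.scalars(
        select(UploadFile).where(UploadFile.id.in_(upload_ids), UploadFile.tenant_id == tenant_id)
    ):
        key_by_id[str(row.id)] = row.key
    for row in session.scalars(
        select(ToolFile).where(ToolFile.id.in_(tool_ids), ToolFile.tenant_id == tenant_id)
    ):
        key_by_id[str(row.id)] = row.file_key
    for f in files:
        # The real loader raises ValueError for tenant mismatches and missing
        # rows; a bare dict lookup stands in for that here.
        f._storage_key = key_by_id[f.related_id]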

View File

@@ -0,0 +1,11 @@
import pytest
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"

View File

@@ -0,0 +1,47 @@
from textwrap import dedent
from .test_utils import CodeExecutorTestMixin
class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
"""Test class for JavaScript code executor functionality."""
def test_javascript_plain(self, flask_app_with_containers):
"""Test basic JavaScript code execution with console.log output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = 'console.log("Hello World")'
result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result_message == "Hello World\n"
def test_javascript_json(self, flask_app_with_containers):
"""Test JavaScript code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = dedent("""
obj = {'Hello': 'World'}
console.log(JSON.stringify(obj))
""")
result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result == '{"Hello":"World"}\n'
def test_javascript_with_code_template(self, flask_app_with_containers):
"""Test JavaScript workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
JavascriptCodeProvider, _ = self.javascript_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JAVASCRIPT,
code=JavascriptCodeProvider.get_default_code(),
inputs={"arg1": "Hello", "arg2": "World"},
)
assert result == {"result": "HelloWorld"}
def test_javascript_get_runner_script(self, flask_app_with_containers):
"""Test JavaScript template transformer runner script generation"""
_, NodeJsTemplateTransformer = self.javascript_imports
runner_script = NodeJsTemplateTransformer.get_runner_script()
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2

View File

@@ -0,0 +1,42 @@
import base64
from .test_utils import CodeExecutorTestMixin
class TestJinja2CodeExecutor(CodeExecutorTestMixin):
"""Test class for Jinja2 code executor functionality."""
def test_jinja2(self, flask_app_with_containers):
"""Test basic Jinja2 template execution with variable substitution"""
CodeExecutor, CodeLanguage = self.code_executor_imports
_, Jinja2TemplateTransformer = self.jinja2_imports
template = "Hello {{template}}"
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
)
assert result == "<<RESULT>>Hello World<<RESULT>>\n"
def test_jinja2_with_code_template(self, flask_app_with_containers):
"""Test Jinja2 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"}
)
assert result == {"result": "Hello World"}
def test_jinja2_get_runner_script(self, flask_app_with_containers):
"""Test Jinja2 template transformer runner script generation"""
_, Jinja2TemplateTransformer = self.jinja2_imports
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2

View File

@@ -0,0 +1,47 @@
from textwrap import dedent
from .test_utils import CodeExecutorTestMixin
class TestPython3CodeExecutor(CodeExecutorTestMixin):
"""Test class for Python3 code executor functionality."""
def test_python3_plain(self, flask_app_with_containers):
"""Test basic Python3 code execution with print output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = 'print("Hello World")'
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == "Hello World\n"
def test_python3_json(self, flask_app_with_containers):
"""Test Python3 code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = dedent("""
import json
print(json.dumps({'Hello': 'World'}))
""")
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == '{"Hello": "World"}\n'
def test_python3_with_code_template(self, flask_app_with_containers):
"""Test Python3 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
Python3CodeProvider, _ = self.python3_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.PYTHON3,
code=Python3CodeProvider.get_default_code(),
inputs={"arg1": "Hello", "arg2": "World"},
)
assert result == {"result": "HelloWorld"}
def test_python3_get_runner_script(self, flask_app_with_containers):
"""Test Python3 template transformer runner script generation"""
_, Python3TemplateTransformer = self.python3_imports
runner_script = Python3TemplateTransformer.get_runner_script()
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2

View File

@@ -0,0 +1,115 @@
"""
Test utilities for code executor integration tests.
This module provides lazy import functions to avoid module loading issues
that occur when modules are imported before the flask_app_with_containers fixture
has set up the proper environment variables and configuration.
"""
import importlib
def force_reload_code_executor():
"""
Force reload the code_executor module to reinitialize code_execution_endpoint_url.
This function should be called after setting up environment variables
to ensure the code_execution_endpoint_url is initialized with the correct value.
"""
try:
import core.helper.code_executor.code_executor
importlib.reload(core.helper.code_executor.code_executor)
except Exception as e:
# Log the error but don't fail the test
print(f"Warning: Failed to reload code_executor module: {e}")
def get_code_executor_imports():
"""
Lazy import function for core CodeExecutor classes.
Returns:
tuple: (CodeExecutor, CodeLanguage) classes
"""
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
return CodeExecutor, CodeLanguage
def get_javascript_imports():
"""
Lazy import function for JavaScript-specific modules.
Returns:
tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes
"""
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
return JavascriptCodeProvider, NodeJsTemplateTransformer
def get_python3_imports():
"""
Lazy import function for Python3-specific modules.
Returns:
tuple: (Python3CodeProvider, Python3TemplateTransformer) classes
"""
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
return Python3CodeProvider, Python3TemplateTransformer
def get_jinja2_imports():
"""
Lazy import function for Jinja2-specific modules.
Returns:
tuple: (None, Jinja2TemplateTransformer) classes
"""
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
return None, Jinja2TemplateTransformer
class CodeExecutorTestMixin:
"""
Mixin class providing lazy import methods for code executor tests.
This mixin helps avoid module loading issues by deferring imports
until after the flask_app_with_containers fixture has set up the environment.
"""
def setup_method(self):
"""
Setup method called before each test method.
Force reload the code_executor module to ensure fresh initialization.
"""
force_reload_code_executor()
@property
def code_executor_imports(self):
"""Property to get CodeExecutor and CodeLanguage classes."""
return get_code_executor_imports()
@property
def javascript_imports(self):
"""Property to get JavaScript-specific classes."""
return get_javascript_imports()
@property
def python3_imports(self):
"""Property to get Python3-specific classes."""
return get_python3_imports()
@property
def jinja2_imports(self):
"""Property to get Jinja2-specific classes."""
return get_jinja2_imports()

View File

@@ -0,0 +1,278 @@
import io
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.common.errors import FilenameNotExistsError
from controllers.console.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
class TestFileUploadSecurity:
"""Test file upload security logic without complex framework setup"""
# Test 1: Basic file validation
def test_should_validate_file_presence(self):
"""Test that missing file is detected"""
from flask import Flask, request
app = Flask(__name__)
with app.test_request_context(method="POST", data={}):
# Simulate the check in FileApi.post()
if "file" not in request.files:
with pytest.raises(NoFileUploadedError):
raise NoFileUploadedError()
def test_should_validate_multiple_files(self):
"""Test that multiple files are rejected"""
from flask import Flask, request
app = Flask(__name__)
file_data = {
"file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"),
"file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"),
}
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
# Simulate the check in FileApi.post()
if len(request.files) > 1:
with pytest.raises(TooManyFilesError):
raise TooManyFilesError()
def test_should_validate_empty_filename(self):
"""Test that empty filename is rejected"""
from flask import Flask, request
app = Flask(__name__)
file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")}
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
file = request.files["file"]
if not file.filename:
with pytest.raises(FilenameNotExistsError):
raise FilenameNotExistsError
# Test 2: Security - Filename sanitization
def test_should_detect_path_traversal_in_filename(self):
"""Test protection against directory traversal attacks"""
dangerous_filenames = [
"../../../etc/passwd",
"..\\..\\windows\\system32\\config\\sam",
"../../../../etc/shadow",
"./../../../sensitive.txt",
]
for filename in dangerous_filenames:
# Any filename containing .. should be considered dangerous
assert ".." in filename, f"Filename {filename} should be detected as path traversal"
def test_should_detect_null_byte_injection(self):
"""Test protection against null byte injection"""
dangerous_filenames = [
"file.jpg\x00.php",
"document.pdf\x00.exe",
"image.png\x00.sh",
]
for filename in dangerous_filenames:
# Null bytes should be detected
assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection"
def test_should_sanitize_special_characters(self):
"""Test that special characters in filenames are handled safely"""
# Characters that could be problematic in various contexts
dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"]
for char in dangerous_chars:
filename = f"file{char}name.txt"
# These characters should be detected or sanitized
assert any(c in filename for c in dangerous_chars)
# Test 3: Permission validation
def test_should_validate_dataset_permissions(self):
"""Test dataset upload permission logic"""
class MockUser:
is_dataset_editor = False
user = MockUser()
source = "datasets"
# Simulate the permission check in FileApi.post()
if source == "datasets" and not user.is_dataset_editor:
with pytest.raises(Forbidden):
raise Forbidden()
def test_should_allow_general_upload_without_permission(self):
"""Test general upload doesn't require dataset permission"""
class MockUser:
is_dataset_editor = False
user = MockUser()
source = None # General upload
# This should not raise an exception
if source == "datasets" and not user.is_dataset_editor:
raise Forbidden()
# Test passes if no exception is raised
# Test 4: Service error handling
@patch("services.file_service.FileService.upload_file")
def test_should_handle_file_too_large_error(self, mock_upload):
"""Test that service FileTooLargeError is properly converted"""
mock_upload.side_effect = ServiceFileTooLargeError("File too large")
try:
mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None)
except ServiceFileTooLargeError as e:
# Simulate the error conversion in FileApi.post()
with pytest.raises(FileTooLargeError):
raise FileTooLargeError(e.description)
@patch("services.file_service.FileService.upload_file")
def test_should_handle_unsupported_file_type_error(self, mock_upload):
"""Test that service UnsupportedFileTypeError is properly converted"""
mock_upload.side_effect = ServiceUnsupportedFileTypeError()
try:
mock_upload(
filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None
)
except ServiceUnsupportedFileTypeError:
# Simulate the error conversion in FileApi.post()
with pytest.raises(UnsupportedFileTypeError):
raise UnsupportedFileTypeError()
# Test 5: File type security
def test_should_identify_dangerous_file_extensions(self):
"""Test detection of potentially dangerous file extensions"""
dangerous_extensions = [
".php",
".PHP",
".pHp", # PHP files (case variations)
".exe",
".EXE", # Executables
".sh",
".SH", # Shell scripts
".bat",
".BAT", # Batch files
".cmd",
".CMD", # Command files
".ps1",
".PS1", # PowerShell
".jar",
".JAR", # Java archives
".vbs",
".VBS", # VBScript
]
safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"]
# Just verify our test data is correct
for ext in dangerous_extensions:
assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
for ext in safe_extensions:
assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
def test_should_detect_double_extensions(self):
"""Test detection of double extension attacks"""
suspicious_filenames = [
"image.jpg.php",
"document.pdf.exe",
"photo.png.sh",
"file.txt.bat",
]
for filename in suspicious_filenames:
# Check that these have multiple extensions
parts = filename.split(".")
assert len(parts) > 2, f"Filename {filename} should have multiple extensions"
# Test 6: Configuration validation
def test_upload_configuration_structure(self):
"""Test that upload configuration has correct structure"""
# Simulate the configuration returned by FileApi.get()
config = {
"file_size_limit": 15,
"batch_count_limit": 5,
"image_file_size_limit": 10,
"video_file_size_limit": 500,
"audio_file_size_limit": 50,
"workflow_file_upload_limit": 10,
}
# Verify all required fields are present
required_fields = [
"file_size_limit",
"batch_count_limit",
"image_file_size_limit",
"video_file_size_limit",
"audio_file_size_limit",
"workflow_file_upload_limit",
]
for field in required_fields:
assert field in config, f"Missing required field: {field}"
assert isinstance(config[field], int), f"Field {field} should be an integer"
assert config[field] > 0, f"Field {field} should be positive"
# Test 7: Source parameter handling
def test_source_parameter_normalization(self):
"""Test that source parameter is properly normalized"""
test_cases = [
("datasets", "datasets"),
("other", None),
("", None),
(None, None),
]
for input_source, expected in test_cases:
# Simulate the source normalization in FileApi.post()
source = "datasets" if input_source == "datasets" else None
if source not in ("datasets", None):
source = None
assert source == expected
# Test 8: Boundary conditions
def test_should_handle_edge_case_file_sizes(self):
"""Test handling of boundary file sizes"""
test_cases = [
(0, "Empty file"), # 0 bytes
(1, "Single byte"), # 1 byte
(15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB
(15 * 1024 * 1024, "At limit"), # Exactly 15MB
(15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB
]
for size, description in test_cases:
# Just verify our test data
assert isinstance(size, int), f"{description}: Size should be integer"
assert size >= 0, f"{description}: Size should be non-negative"
def test_should_handle_special_mime_types(self):
"""Test handling of various MIME types"""
mime_type_tests = [
("application/octet-stream", "Generic binary"),
("text/plain", "Plain text"),
("image/jpeg", "JPEG image"),
("application/pdf", "PDF document"),
("", "Empty MIME type"),
(None, "None MIME type"),
]
for mime_type, description in mime_type_tests:
# Verify test data structure
if mime_type is not None:
assert isinstance(mime_type, str), f"{description}: MIME type should be string or None"
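A minimal sanitizer sketch consistent with the attack classes these tests enumerate (illustrative only, not the production validator):

import os
import re

def sanitize_filename(filename: str) -> str:
    # Reject null bytes and path traversal outright.
    if "\x00" in filename or ".." in filename:
        raise ValueError("unsafe filename")
    # Keep only the basename, then neutralize characters that are risky
    # across filesystems and shells.
    base = os.path.basename(filename.replace("\\", "/"))
    return re.sub(r'[/\\:*?"<>|]', "_", base)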

View File

@@ -1,4 +1,4 @@
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
class TestSegmentTypeIsArrayType:
@@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType:
value is tested for the is_array_type method.
"""
# Arrange
all_segment_types = set(SegmentType)
expected_array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
@@ -58,3 +57,27 @@
for seg_type in enum_values:
is_array = seg_type.is_array_type()
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
class TestSegmentTypeIsValidArrayValidation:
"""
Test SegmentType.is_valid with array types using different validation strategies.
"""
def test_array_validation_all_success(self):
value = ["hello", "world", "foo"]
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
def test_array_validation_all_fail(self):
value = ["hello", 123, "world"]
# Should return False, since 123 is not a string
assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
def test_array_validation_first(self):
value = ["hello", 123, None]
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
def test_array_validation_none(self):
value = [1, 2, 3]
# validation strategy is NONE, so element checks are skipped
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
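In plain terms: NONE accepts any list without inspecting elements, FIRST validates only the first element, and ALL validates every element. A schematic check, not the real is_valid implementation:

from core.variables.types import ArrayValidation

def is_valid_array_sketch(values, element_ok, array_validation):
    # element_ok is a per-element predicate, e.g. lambda v: isinstance(v, str)
    if array_validation == ArrayValidation.NONE:
        return True
    if array_validation == ArrayValidation.FIRST:
        return element_ok(values[0]) if values else True
    return all(element_ok(v) for v in values)  # ArrayValidation.ALL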

View File

@@ -1318,6 +1318,7 @@ dev = [
{ name = "pytest-mock" },
{ name = "ruff" },
{ name = "scipy-stubs" },
{ name = "testcontainers" },
{ name = "types-aiofiles" },
{ name = "types-beautifulsoup4" },
{ name = "types-cachetools" },
@@ -1500,6 +1501,7 @@ dev = [
{ name = "pytest-mock", specifier = "~=3.14.0" },
{ name = "ruff", specifier = "~=0.12.3" },
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "testcontainers", specifier = "~=4.10.0" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
{ name = "types-cachetools", specifier = "~=5.5.0" },
@@ -1600,6 +1602,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]
[[package]]
name = "docker"
version = "7.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
]
[[package]]
name = "docstring-parser"
version = "0.16"
@@ -5468,6 +5484,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
]
[[package]]
name = "testcontainers"
version = "4.10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "docker" },
{ name = "python-dotenv" },
{ name = "typing-extensions" },
{ name = "urllib3" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" },
]
[[package]]
name = "tidb-vector"
version = "0.0.9"

View File

@@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh
# Unit tests
dev/pytest/pytest_unit_tests.sh
# TestContainers tests
dev/pytest/pytest_testcontainers.sh

View File

@@ -0,0 +1,7 @@
#!/bin/bash
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/test_containers_integration_tests

View File

@@ -653,6 +653,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
TABLESTORE_INSTANCE_NAME=instance-name
TABLESTORE_ACCESS_KEY_ID=xxx
TABLESTORE_ACCESS_KEY_SECRET=xxx
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
# ------------------------------
# Knowledge Configuration

View File

@@ -312,6 +312,7 @@ x-shared-env: &shared-api-worker-env
TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
ETL_TYPE: ${ETL_TYPE:-dify}

View File

@@ -265,7 +265,6 @@ export default translation
fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content)
const allEnKeys = await getKeysFromLanguage('en-US')
const allZhKeys = await getKeysFromLanguage('zh-Hans')
// Test file filtering logic
const targetFile = 'components'
@@ -563,4 +562,201 @@
expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys
})
})
describe('Auto-remove multiline key-value pairs', () => {
// Helper function to simulate removeExtraKeysFromFile logic
function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string {
const lines = content.split('\n')
const linesToRemove: number[] = []
for (const keyToRemove of keysToRemove) {
let targetLineIndex = -1
const linesToRemoveForKey: number[] = []
// Find the key line (simplified for single-level keys in test)
for (let i = 0; i < lines.length; i++) {
const line = lines[i]
const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`)
if (keyPattern.test(line)) {
targetLineIndex = i
break
}
}
if (targetLineIndex !== -1) {
linesToRemoveForKey.push(targetLineIndex)
// Check if this is a multiline key-value pair
const keyLine = lines[targetLineIndex]
const trimmedKeyLine = keyLine.trim()
// If key line ends with ":" (not complete value), it's likely multiline
if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) {
// Find the value lines that belong to this key
let currentLine = targetLineIndex + 1
let foundValue = false
while (currentLine < lines.length) {
const line = lines[currentLine]
const trimmed = line.trim()
// Skip empty lines
if (trimmed === '') {
currentLine++
continue
}
// Check if this line starts a new key (indicates end of current value)
if (trimmed.match(/^\w+\s*:/))
break
// Check if this line is part of the value
if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) {
linesToRemoveForKey.push(currentLine)
foundValue = true
// Check if this line ends the value (ends with quote and comma/no comma)
if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,')
|| trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`'))
&& !trimmed.startsWith('//'))
break
}
else {
break
}
currentLine++
}
}
linesToRemove.push(...linesToRemoveForKey)
}
}
// Remove duplicates and sort in reverse order
const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a)
for (const lineIndex of uniqueLinesToRemove)
lines.splice(lineIndex, 1)
return lines.join('\n')
}
it('should remove single-line key-value pairs correctly', () => {
const content = `const translation = {
keepThis: 'This should stay',
removeThis: 'This should be removed',
alsoKeep: 'This should also stay',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeThis'])
expect(result).toContain('keepThis: \'This should stay\'')
expect(result).toContain('alsoKeep: \'This should also stay\'')
expect(result).not.toContain('removeThis: \'This should be removed\'')
})
it('should remove multiline key-value pairs completely', () => {
const content = `const translation = {
keepThis: 'This should stay',
removeMultiline:
'This is a multiline value that should be removed completely',
alsoKeep: 'This should also stay',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeMultiline'])
expect(result).toContain('keepThis: \'This should stay\'')
expect(result).toContain('alsoKeep: \'This should also stay\'')
expect(result).not.toContain('removeMultiline:')
expect(result).not.toContain('This is a multiline value that should be removed completely')
})
it('should handle mixed single-line and multiline removals', () => {
const content = `const translation = {
keepThis: 'Keep this',
removeSingle: 'Remove this single line',
removeMultiline:
'Remove this multiline value',
anotherMultiline:
'Another multiline that spans multiple lines',
keepAnother: 'Keep this too',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline'])
expect(result).toContain('keepThis: \'Keep this\'')
expect(result).toContain('keepAnother: \'Keep this too\'')
expect(result).not.toContain('removeSingle:')
expect(result).not.toContain('removeMultiline:')
expect(result).not.toContain('anotherMultiline:')
expect(result).not.toContain('Remove this single line')
expect(result).not.toContain('Remove this multiline value')
expect(result).not.toContain('Another multiline that spans multiple lines')
})
it('should properly detect multiline vs single-line patterns', () => {
const multilineContent = `const translation = {
singleLine: 'This is single line',
multilineKey:
'This is multiline',
keyWithColon: 'Value with: colon inside',
objectKey: {
nested: 'value'
},
}
export default translation`
// Test that single line with colon in value is not treated as multiline
const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon'])
expect(result1).not.toContain('keyWithColon:')
expect(result1).not.toContain('Value with: colon inside')
// Test that true multiline is handled correctly
const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey'])
expect(result2).not.toContain('multilineKey:')
expect(result2).not.toContain('This is multiline')
// Test that object key removal works (note: this is a simplified test)
// In real scenario, object removal would be more complex
const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey'])
expect(result3).not.toContain('objectKey: {')
// Note: Our simplified test function doesn't handle nested object removal perfectly
// This is acceptable as it's testing the main multiline string removal functionality
})
it('should handle real-world Polish translation structure', () => {
const polishContent = `const translation = {
createApp: 'UTWÓRZ APLIKACJĘ',
newApp: {
captionAppType: 'Jaki typ aplikacji chcesz stworzyć?',
chatbotDescription:
'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.',
agentDescription:
'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.',
basic: 'Podstawowy',
},
}
export default translation`
const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription'])
expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'')
expect(result).toContain('basic: \'Podstawowy\'')
expect(result).not.toContain('captionAppType:')
expect(result).not.toContain('chatbotDescription:')
expect(result).not.toContain('agentDescription:')
expect(result).not.toContain('Jaki typ aplikacji')
expect(result).not.toContain('Zbuduj aplikację opartą na czacie')
expect(result).not.toContain('Zbuduj inteligentnego agenta')
})
})
})

View File

@@ -0,0 +1,305 @@
/**
* Document Detail Navigation Fix Verification Test
*
* This test specifically validates that the backToPrev function in the document detail
* component correctly preserves pagination and filter states.
*/
import { fireEvent, render, screen } from '@testing-library/react'
import { useRouter } from 'next/navigation'
import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document'
// Mock Next.js router
const mockPush = jest.fn()
jest.mock('next/navigation', () => ({
useRouter: jest.fn(() => ({
push: mockPush,
})),
}))
// Mock the document service hooks
jest.mock('@/service/knowledge/use-document', () => ({
useDocumentDetail: jest.fn(),
useDocumentMetadata: jest.fn(),
useInvalidDocumentList: jest.fn(() => jest.fn()),
}))
// Mock other dependencies
jest.mock('@/context/dataset-detail', () => ({
useDatasetDetailContext: jest.fn(() => [null]),
}))
jest.mock('@/service/use-base', () => ({
useInvalid: jest.fn(() => jest.fn()),
}))
jest.mock('@/service/knowledge/use-segment', () => ({
useSegmentListKey: jest.fn(),
useChildSegmentListKey: jest.fn(),
}))
// Create a minimal version of the DocumentDetail component that includes our fix
const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => {
const router = useRouter()
// This is the FIXED implementation from detail/index.tsx
const backToPrev = () => {
// Preserve pagination and filter states when navigating back
const searchParams = new URLSearchParams(window.location.search)
const queryString = searchParams.toString()
const separator = queryString ? '?' : ''
const backPath = `/datasets/${datasetId}/documents${separator}${queryString}`
router.push(backPath)
}
return (
<div data-testid="document-detail-fixed">
<button data-testid="back-button-fixed" onClick={backToPrev}>
Back to Documents
</button>
<div data-testid="document-info">
Dataset: {datasetId}, Document: {documentId}
</div>
</div>
)
}
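// For contrast, a hypothetical reconstruction of the pre-fix handler this
// test guards against (sketch only, not the actual original source): it
// dropped window.location.search, losing pagination and filter state.
// const backToPrevBroken = () => {
//   router.push(`/datasets/${datasetId}/documents`)
// }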
describe('Document Detail Navigation Fix Verification', () => {
beforeEach(() => {
jest.clearAllMocks()
// Mock successful API responses
;(useDocumentDetail as jest.Mock).mockReturnValue({
data: {
id: 'doc-123',
name: 'Test Document',
display_status: 'available',
enabled: true,
archived: false,
},
error: null,
})
;(useDocumentMetadata as jest.Mock).mockReturnValue({
data: null,
error: null,
})
})
describe('Query Parameter Preservation', () => {
test('preserves pagination state (page 3, limit 25)', () => {
// Simulate user coming from page 3 with 25 items per page
Object.defineProperty(window, 'location', {
value: {
search: '?page=3&limit=25',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
// User clicks back button
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should preserve the pagination state
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25')
console.log('✅ Pagination state preserved: page=3&limit=25')
})
test('preserves search keyword and filters', () => {
// Simulate user with search and filters applied
Object.defineProperty(window, 'location', {
value: {
search: '?page=2&limit=10&keyword=API%20documentation&status=active',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should preserve all query parameters
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active')
console.log('✅ Search and filters preserved')
})
test('handles complex query parameters with special characters', () => {
// Test with complex query string including encoded characters
Object.defineProperty(window, 'location', {
value: {
search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// URLSearchParams will normalize the encoding, but preserve all parameters
const expectedCall = mockPush.mock.calls[0][0]
expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/)
expect(expectedCall).toMatch(/page=1/)
expect(expectedCall).toMatch(/limit=50/)
expect(expectedCall).toMatch(/keyword=test/)
expect(expectedCall).toMatch(/sort=name/)
expect(expectedCall).toMatch(/order=desc/)
console.log('✅ Complex query parameters handled:', expectedCall)
})
test('handles empty query parameters gracefully', () => {
// No query parameters in URL
Object.defineProperty(window, 'location', {
value: {
search: '',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should navigate to clean documents URL
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents')
console.log('✅ Empty parameters handled gracefully')
})
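test('URLSearchParams normalizes percent-encoded spaces to "+"', () => {
// A small supporting check (sketch): URLSearchParams round-trips normalize
// percent-encoding, which is why '%20' becomes '+' in the assertions above.
const params = new URLSearchParams('?keyword=API%20documentation')
expect(params.toString()).toBe('keyword=API+documentation')
})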
})
describe('Different Dataset IDs', () => {
test('works with different dataset identifiers', () => {
Object.defineProperty(window, 'location', {
value: {
search: '?page=5&limit=10',
},
writable: true,
})
// Test with different dataset ID format
render(<DocumentDetailWithFix datasetId="ds-prod-2024-001" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10')
console.log('✅ Works with different dataset ID formats')
})
})
describe('Real User Scenarios', () => {
test('scenario: user searches, goes to page 3, views document, clicks back', () => {
// User searched for "API" and navigated to page 3
Object.defineProperty(window, 'location', {
value: {
search: '?keyword=API&page=3&limit=10',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="main-dataset" documentId="api-doc-123" />)
// User decides to go back to continue browsing
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should return to page 3 of API search results
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10')
console.log('✅ Real user scenario: search + pagination preserved')
})
test('scenario: user applies multiple filters, goes to document, returns', () => {
// User has applied multiple filters and is on page 2
Object.defineProperty(window, 'location', {
value: {
search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="filtered-dataset" documentId="filtered-doc" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// All filters should be preserved
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc')
console.log('✅ Complex filtering scenario preserved')
})
})
describe('Error Handling and Edge Cases', () => {
test('handles malformed query parameters gracefully', () => {
// Test with potentially problematic query string
Object.defineProperty(window, 'location', {
value: {
search: '?page=invalid&limit=&keyword=test&=emptykey&malformed',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
// Should not throw errors
expect(() => {
fireEvent.click(screen.getByTestId('back-button-fixed'))
}).not.toThrow()
// Should still attempt navigation (URLSearchParams will clean up the parameters)
expect(mockPush).toHaveBeenCalled()
const navigationPath = mockPush.mock.calls[0][0]
expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/)
console.log('✅ Malformed parameters handled gracefully:', navigationPath)
})
test('handles very long query strings', () => {
// Test with a very long query string
const longKeyword = 'a'.repeat(1000)
Object.defineProperty(window, 'location', {
value: {
search: `?page=1&keyword=${longKeyword}`,
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
expect(() => {
fireEvent.click(screen.getByTestId('back-button-fixed'))
}).not.toThrow()
expect(mockPush).toHaveBeenCalled()
console.log('✅ Long query strings handled')
})
})
describe('Performance Verification', () => {
test('navigation function executes quickly', () => {
Object.defineProperty(window, 'location', {
value: {
search: '?page=1&limit=10&keyword=test',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
const startTime = performance.now()
fireEvent.click(screen.getByTestId('back-button-fixed'))
const endTime = performance.now()
const executionTime = endTime - startTime
// Should execute in less than 10ms
expect(executionTime).toBeLessThan(10)
console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`)
})
})
})

View File

@ -0,0 +1,83 @@
/**
* Document List Sorting Tests
*/
describe('Document List Sorting', () => {
const mockDocuments = [
{ id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 },
{ id: '2', name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 },
{ id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 },
]
const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => {
return [...docs].sort((a, b) => {
let aValue: any
let bValue: any
switch (field) {
case 'name':
aValue = a.name?.toLowerCase() || ''
bValue = b.name?.toLowerCase() || ''
break
case 'word_count':
aValue = a.word_count || 0
bValue = b.word_count || 0
break
case 'hit_count':
aValue = a.hit_count || 0
bValue = b.hit_count || 0
break
case 'created_at':
aValue = a.created_at
bValue = b.created_at
break
default:
return 0
}
if (field === 'name') {
const result = aValue.localeCompare(bValue)
return order === 'asc' ? result : -result
}
else {
const result = aValue - bValue
return order === 'asc' ? result : -result
}
})
}
test('sorts by name descending (default for UI consistency)', () => {
const sorted = sortDocuments(mockDocuments, 'name', 'desc')
expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt'])
})
test('sorts by name ascending (after toggle)', () => {
const sorted = sortDocuments(mockDocuments, 'name', 'asc')
expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx'])
})
test('sorts by word_count descending', () => {
const sorted = sortDocuments(mockDocuments, 'word_count', 'desc')
expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200])
})
test('sorts by hit_count descending', () => {
const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc')
expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5])
})
test('sorts by created_at descending (newest first)', () => {
const sorted = sortDocuments(mockDocuments, 'created_at', 'desc')
expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400])
})
test('handles empty values correctly', () => {
const docsWithEmpty = [
{ id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 },
{ id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 },
]
const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc')
expect(sorted.map(doc => doc.word_count)).toEqual([100, 0])
})
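test('keeps original order for equal sort keys (stable sort assumption)', () => {
// Sketch documenting an assumption: Array.prototype.sort is stable per
// ES2019, so documents with equal sort keys keep their incoming order.
const docs = [
{ id: 'a', name: 'Same', word_count: 100, hit_count: 1, created_at: 1 },
{ id: 'b', name: 'Same', word_count: 100, hit_count: 2, created_at: 2 },
]
const sorted = sortDocuments(docs, 'word_count', 'desc')
expect(sorted.map(doc => doc.id)).toEqual(['a', 'b'])
})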
})

View File

@ -0,0 +1,290 @@
/**
* Navigation Utilities Test
*
* Tests for the navigation utility functions to ensure they handle
* query parameter preservation correctly across different scenarios.
*/
import {
createBackNavigation,
createNavigationPath,
createNavigationPathWithParams,
datasetNavigation,
extractQueryParams,
mergeQueryParams,
} from '@/utils/navigation'
// Mock router for testing
const mockPush = jest.fn()
const mockRouter = { push: mockPush }
describe('Navigation Utilities', () => {
beforeEach(() => {
jest.clearAllMocks()
})
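// For orientation, a minimal sketch of what createNavigationPath is assumed
// to do (hypothetical; the real implementation lives in @/utils/navigation):
// function createNavigationPath(basePath: string, preserveParams = true): string {
//   if (!preserveParams) return basePath
//   try {
//     const qs = new URLSearchParams(window.location.search).toString()
//     return qs ? `${basePath}?${qs}` : basePath
//   }
//   catch (error) {
//     console.warn('Failed to preserve query parameters:', error)
//     return basePath
//   }
// }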
describe('createNavigationPath', () => {
test('preserves query parameters by default', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test')
})
test('returns clean path when preserveParams is false', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents', false)
expect(path).toBe('/datasets/123/documents')
})
test('handles empty query parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents')
})
test('handles errors gracefully', () => {
// Mock window.location to throw an error
Object.defineProperty(window, 'location', {
get: () => {
throw new Error('Location access denied')
},
configurable: true,
})
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents')
expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error))
consoleSpy.mockRestore()
})
})
describe('createBackNavigation', () => {
test('creates function that navigates with preserved params', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25')
})
test('creates function that navigates without params when specified', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false)
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents')
})
})
describe('extractQueryParams', () => {
test('extracts specified parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test&other=value' },
writable: true,
})
const params = extractQueryParams(['page', 'limit', 'keyword'])
expect(params).toEqual({
page: '3',
limit: '10',
keyword: 'test',
})
})
test('handles missing parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3' },
writable: true,
})
const params = extractQueryParams(['page', 'limit', 'missing'])
expect(params).toEqual({
page: '3',
})
})
test('handles errors gracefully', () => {
Object.defineProperty(window, 'location', {
get: () => {
throw new Error('Location access denied')
},
configurable: true,
})
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const params = extractQueryParams(['page', 'limit'])
expect(params).toEqual({})
expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error))
consoleSpy.mockRestore()
})
})
describe('createNavigationPathWithParams', () => {
test('creates path with specified parameters', () => {
const path = createNavigationPathWithParams('/datasets/123/documents', {
page: 1,
limit: 25,
keyword: 'search term',
})
expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term')
})
test('filters out empty values', () => {
const path = createNavigationPathWithParams('/datasets/123/documents', {
page: 1,
limit: '',
keyword: 'test',
empty: null,
undefined,
})
expect(path).toBe('/datasets/123/documents?page=1&keyword=test')
})
test('handles errors gracefully', () => {
// Mock URLSearchParams to throw an error
const originalURLSearchParams = globalThis.URLSearchParams
globalThis.URLSearchParams = jest.fn(() => {
throw new Error('URLSearchParams error')
}) as any
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 })
expect(path).toBe('/datasets/123/documents')
expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error))
consoleSpy.mockRestore()
globalThis.URLSearchParams = originalURLSearchParams
})
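// Assumed shape of the filtering the tests above rely on (sketch only):
// keep entries whose value is neither null, undefined, nor '', and append
// '?' plus the query string only when something survives the filter.
// const entries = Object.entries(params)
//   .filter(([, v]) => v !== null && v !== undefined && v !== '')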
})
describe('mergeQueryParams', () => {
test('merges new params with existing ones', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const merged = mergeQueryParams({ keyword: 'test', page: '1' })
const result = merged.toString()
expect(result).toContain('page=1') // overridden
expect(result).toContain('limit=10') // preserved
expect(result).toContain('keyword=test') // added
})
test('removes parameters when value is null', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test' },
writable: true,
})
const merged = mergeQueryParams({ keyword: null, filter: 'active' })
const result = merged.toString()
expect(result).toContain('page=3')
expect(result).toContain('limit=10')
expect(result).not.toContain('keyword')
expect(result).toContain('filter=active')
})
test('creates fresh params when preserveExisting is false', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const merged = mergeQueryParams({ keyword: 'test' }, false)
const result = merged.toString()
expect(result).toBe('keyword=test')
})
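test('merged params can drive a router push (usage sketch)', () => {
// Usage sketch; assumes mergeQueryParams returns a URLSearchParams
// instance, as the .toString() calls above already rely on.
Object.defineProperty(window, 'location', {
value: { search: '?page=3' },
writable: true,
})
const merged = mergeQueryParams({ keyword: 'docs' })
mockRouter.push(`/datasets/123/documents?${merged.toString()}`)
const navigationPath = mockPush.mock.calls[0][0]
expect(navigationPath).toContain('page=3')
expect(navigationPath).toContain('keyword=docs')
})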
})
describe('datasetNavigation', () => {
test('backToDocuments creates correct navigation function', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25')
})
test('toDocumentDetail creates correct navigation function', () => {
const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456')
detailFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456')
})
test('toDocumentSettings creates correct navigation function', () => {
const settingsFn = datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456')
settingsFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings')
})
})
describe('Real-world Integration Scenarios', () => {
test('complete user workflow: list -> detail -> back', () => {
// User starts on page 3 with search
Object.defineProperty(window, 'location', {
value: { search: '?page=3&keyword=API&limit=25' },
writable: true,
})
// Create back navigation function (as would be done in detail component)
const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset')
// User clicks back
backToDocuments()
// Should return to exact same list state
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25')
})
test('user applies filters then views document', () => {
// Complex filter state
Object.defineProperty(window, 'location', {
value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc')
})
})
})

View File

@ -0,0 +1,396 @@
/**
* Unified Tags Editing - Pure Logic Tests
*
* This test file validates the core business logic and state management
* behaviors introduced in the 7 most recent commits without requiring complex mocks.
*/
describe('Unified Tags Editing - Pure Logic Tests', () => {
describe('Tag State Management Logic', () => {
it('should detect when tag values have changed', () => {
const currentValue = ['tag1', 'tag2']
const newSelectedTagIDs = ['tag1', 'tag3']
// This is the valueNotChanged logic from TagSelector component
const valueNotChanged
= currentValue.length === newSelectedTagIDs.length
&& currentValue.every(v => newSelectedTagIDs.includes(v))
&& newSelectedTagIDs.every(v => currentValue.includes(v))
expect(valueNotChanged).toBe(false)
})
it('should correctly identify unchanged tag values', () => {
const currentValue = ['tag1', 'tag2']
const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order
const valueNotChanged
= currentValue.length === newSelectedTagIDs.length
&& currentValue.every(v => newSelectedTagIDs.includes(v))
&& newSelectedTagIDs.every(v => currentValue.includes(v))
expect(valueNotChanged).toBe(true)
})
it('should calculate correct tag operations for binding/unbinding', () => {
const currentValue = ['tag1', 'tag2']
const selectedTagIDs = ['tag2', 'tag3']
// This is the handleValueChange logic from TagSelector
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual(['tag3'])
expect(removeTagIDs).toEqual(['tag1'])
})
it('should handle empty tag arrays correctly', () => {
const currentValue: string[] = []
const selectedTagIDs = ['tag1']
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual(['tag1'])
expect(removeTagIDs).toEqual([])
expect(currentValue.length).toBe(0) // Verify empty array usage
})
it('should handle removing all tags', () => {
const currentValue = ['tag1', 'tag2']
const selectedTagIDs: string[] = []
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual([])
expect(removeTagIDs).toEqual(['tag1', 'tag2'])
expect(selectedTagIDs.length).toBe(0) // Verify empty array usage
})
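it('should produce the same diff with a Set-based variant (sketch)', () => {
// Hedged aside: the filter/includes diff above is O(n * m); a Set-based
// variant (illustration only, not the component's code) matches it in
// roughly linear time for large tag sets.
const currentValue = ['tag1', 'tag2']
const selectedTagIDs = ['tag2', 'tag3']
const current = new Set(currentValue)
const selected = new Set(selectedTagIDs)
const addTagIDs = selectedTagIDs.filter(v => !current.has(v))
const removeTagIDs = currentValue.filter(v => !selected.has(v))
expect(addTagIDs).toEqual(['tag3'])
expect(removeTagIDs).toEqual(['tag1'])
})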
})
describe('Fallback Logic (from layout-main.tsx)', () => {
it('should trigger fallback when tags are missing or empty', () => {
const appDetailWithoutTags = { tags: [] }
const appDetailWithTags = { tags: [{ id: 'tag1' }] }
const appDetailWithUndefinedTags = { tags: undefined as any }
// This simulates the condition in layout-main.tsx
const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0
expect(shouldFallback1).toBe(true) // Empty array should trigger fallback
expect(shouldFallback2).toBe(false) // Has tags, no fallback needed
expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback
})
it('should preserve tags when fallback succeeds', () => {
const originalAppDetail = { tags: [] as any[] }
const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
// This simulates the successful fallback in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
expect(originalAppDetail.tags.length).toBe(1)
})
it('should continue with empty tags when fallback fails', () => {
const originalAppDetail: { tags: any[] } = { tags: [] }
const fallbackResult: { tags?: any[] } | null = null
// This simulates fallback failure in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
expect(originalAppDetail.tags).toEqual([])
})
})
describe('TagSelector Auto-initialization Logic', () => {
it('should trigger getTagList when tagList is empty', () => {
const tagList: any[] = []
let getTagListCalled = false
const getTagList = () => {
getTagListCalled = true
}
// This simulates the useEffect in TagSelector
if (tagList.length === 0)
getTagList()
expect(getTagListCalled).toBe(true)
})
it('should not trigger getTagList when tagList has items', () => {
const tagList = [{ id: 'tag1', name: 'existing-tag' }]
let getTagListCalled = false
const getTagList = () => {
getTagListCalled = true
}
// This simulates the useEffect in TagSelector
if (tagList.length === 0)
getTagList()
expect(getTagListCalled).toBe(false)
})
})
describe('State Initialization Patterns', () => {
it('should maintain AppCard tag state pattern', () => {
const app = { tags: [{ id: 'tag1', name: 'test' }] }
// Original AppCard pattern: useState(app.tags)
const initialTags = app.tags
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(1)
expect(initialTags).toBe(app.tags) // Reference equality for AppCard
})
it('should maintain AppInfo tag state pattern', () => {
const appDetail = { tags: [{ id: 'tag1', name: 'test' }] }
// New AppInfo pattern: useState(appDetail?.tags || [])
const initialTags = appDetail?.tags || []
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(1)
})
it('should handle undefined appDetail gracefully in AppInfo', () => {
const appDetail = undefined
// AppInfo pattern with undefined appDetail
const initialTags = (appDetail as any)?.tags || []
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(0)
})
})
describe('CSS Class and Layout Logic', () => {
it('should apply correct minimum width condition', () => {
const minWidth = 'true'
// This tests the minWidth logic in TagSelector
const shouldApplyMinWidth = minWidth && '!min-w-80'
expect(shouldApplyMinWidth).toBe('!min-w-80')
})
it('should not apply minimum width when not specified', () => {
const minWidth = undefined
const shouldApplyMinWidth = minWidth && '!min-w-80'
expect(shouldApplyMinWidth).toBeFalsy()
})
it('should handle overflow layout classes correctly', () => {
// This tests the layout pattern from AppCard and new AppInfo
const overflowLayoutClasses = {
container: 'flex w-0 grow items-center',
inner: 'w-full',
truncate: 'truncate',
}
expect(overflowLayoutClasses.container).toContain('w-0 grow')
expect(overflowLayoutClasses.inner).toContain('w-full')
expect(overflowLayoutClasses.truncate).toBe('truncate')
})
})
describe('fetchAppWithTags Service Logic', () => {
it('should correctly find app by ID from app list', () => {
const appList = [
{ id: 'app1', name: 'App 1', tags: [] },
{ id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] },
{ id: 'app3', name: 'App 3', tags: [] },
]
const targetAppId = 'test-app-id'
// This simulates the logic in fetchAppWithTags
const foundApp = appList.find(app => app.id === targetAppId)
expect(foundApp).toBeDefined()
expect(foundApp?.id).toBe('test-app-id')
expect(foundApp?.tags.length).toBe(1)
})
it('should return null when app not found', () => {
const appList = [
{ id: 'app1', name: 'App 1' },
{ id: 'app2', name: 'App 2' },
]
const targetAppId = 'nonexistent-app'
const foundApp = appList.find(app => app.id === targetAppId) || null
expect(foundApp).toBeNull()
})
it('should handle empty app list', () => {
const appList: any[] = []
const targetAppId = 'any-app'
const foundApp = appList.find(app => app.id === targetAppId) || null
expect(foundApp).toBeNull()
expect(appList.length).toBe(0) // Verify empty array usage
})
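// For orientation, the assumed shape of fetchAppWithTags (hypothetical
// reconstruction; names and pagination values mirror the API-parameter
// expectations elsewhere in this file):
// const fetchAppWithTags = async (appId: string) => {
//   const { data } = await fetchAppList({ url: '/apps', params: { page: 1, limit: 100 } })
//   return data.find(app => app.id === appId) || null
// }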
})
describe('Data Structure Validation', () => {
it('should maintain consistent tag data structure', () => {
const tag = {
id: 'tag1',
name: 'test-tag',
type: 'app',
binding_count: 1,
}
expect(tag).toHaveProperty('id')
expect(tag).toHaveProperty('name')
expect(tag).toHaveProperty('type')
expect(tag).toHaveProperty('binding_count')
expect(tag.type).toBe('app')
expect(typeof tag.binding_count).toBe('number')
})
it('should handle tag arrays correctly', () => {
const tags = [
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
{ id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 },
]
expect(Array.isArray(tags)).toBe(true)
expect(tags.length).toBe(2)
expect(tags.every(tag => tag.type === 'app')).toBe(true)
})
it('should validate app data structure with tags', () => {
const app = {
id: 'test-app',
name: 'Test App',
tags: [
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
],
}
expect(app).toHaveProperty('id')
expect(app).toHaveProperty('name')
expect(app).toHaveProperty('tags')
expect(Array.isArray(app.tags)).toBe(true)
expect(app.tags.length).toBe(1)
})
})
describe('Performance and Edge Cases', () => {
it('should handle large tag arrays efficiently', () => {
const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`)
const selectedTags = ['tag1', 'tag50', 'tag99']
// Performance test: filtering should be efficient
const startTime = Date.now()
const addTags = selectedTags.filter(tag => !largeTags.includes(tag))
const removeTags = largeTags.filter(tag => !selectedTags.includes(tag))
const endTime = Date.now()
expect(endTime - startTime).toBeLessThan(10) // Should be very fast
expect(addTags.length).toBe(0) // All selected tags exist
expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove
})
it('should handle malformed tag data gracefully', () => {
const mixedData = [
{ id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 },
{ id: 'invalid1' }, // Missing required properties
null,
undefined,
{ id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 },
]
// Filter out invalid entries
const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } =>
tag != null
&& typeof tag === 'object'
&& 'id' in tag
&& 'name' in tag
&& 'type' in tag
&& 'binding_count' in tag
&& typeof tag.binding_count === 'number',
)
expect(validTags.length).toBe(2)
expect(validTags.every(tag => tag.id && tag.name)).toBe(true)
})
it('should handle concurrent tag operations correctly', () => {
const operations = [
{ type: 'add', tagIds: ['tag1', 'tag2'] },
{ type: 'remove', tagIds: ['tag3'] },
{ type: 'add', tagIds: ['tag4'] },
]
// Simulate processing operations
const results = operations.map(op => ({
...op,
processed: true,
timestamp: Date.now(),
}))
expect(results.length).toBe(3)
expect(results.every(result => result.processed)).toBe(true)
})
})
describe('Backward Compatibility Verification', () => {
it('should not break existing AppCard behavior', () => {
// Verify AppCard continues to work with original patterns
const originalAppCardLogic = {
initializeTags: (app: any) => app.tags,
updateTags: (_currentTags: any[], newTags: any[]) => newTags,
shouldRefresh: true,
}
const app = { tags: [{ id: 'tag1', name: 'original' }] }
const initializedTags = originalAppCardLogic.initializeTags(app)
expect(initializedTags).toBe(app.tags)
expect(originalAppCardLogic.shouldRefresh).toBe(true)
})
it('should ensure AppInfo follows AppCard patterns', () => {
// Verify AppInfo uses compatible state management
const appCardPattern = (app: any) => app.tags
const appInfoPattern = (appDetail: any) => appDetail?.tags || []
const appWithTags = { tags: [{ id: 'tag1' }] }
const appWithoutTags = { tags: [] }
const undefinedApp = undefined
expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags))
expect(appInfoPattern(appWithoutTags)).toEqual([])
expect(appInfoPattern(undefinedApp)).toEqual([])
})
it('should maintain consistent API parameters', () => {
// Verify service layer maintains expected parameters
const fetchAppListParams = {
url: '/apps',
params: { page: 1, limit: 100 },
}
const tagApiParams = {
bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }),
unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }),
}
expect(fetchAppListParams.url).toBe('/apps')
expect(fetchAppListParams.params.limit).toBe(100)
const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app')
expect(bindResult.tagIDs).toEqual(['tag1'])
expect(bindResult.type).toBe('app')
})
})
})

View File

@ -0,0 +1,212 @@
/**
* XSS Fix Verification Test
*
* This test verifies that the XSS vulnerability in check-code pages has been
* properly fixed by replacing dangerouslySetInnerHTML with safe React rendering.
*/
import React from 'react'
import { cleanup, render } from '@testing-library/react'
import '@testing-library/jest-dom'
// Mock i18next with the new safe translation structure
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
if (key === 'login.checkCode.tipsPrefix')
return 'We send a verification code to '
return key
},
}),
}))
// Mock Next.js useSearchParams
jest.mock('next/navigation', () => ({
useSearchParams: () => ({
get: (key: string) => {
if (key === 'email')
return 'test@example.com<script>alert("XSS")</script>'
return null
},
}),
}))
// Fixed CheckCode component implementation (current secure version)
const SecureCheckCodeComponent = ({ email }: { email: string }) => {
const { t } = require('react-i18next').useTranslation()
return (
<div>
<h1>Check Code</h1>
<p>
<span>
{t('login.checkCode.tipsPrefix')}
<strong>{email}</strong>
</span>
</p>
</div>
)
}
// Vulnerable implementation for comparison (what we fixed)
const VulnerableCheckCodeComponent = ({ email }: { email: string }) => {
const mockTranslation = (key: string, params?: any) => {
if (key === 'login.checkCode.tips' && params?.email)
return `We send a verification code to <strong>${params.email}</strong>`
return key
}
return (
<div>
<h1>Check Code</h1>
<p>
<span dangerouslySetInnerHTML={{ __html: mockTranslation('login.checkCode.tips', { email }) }}></span>
</p>
</div>
)
}
describe('XSS Fix Verification - Check Code Pages Security', () => {
afterEach(() => {
cleanup()
})
const maliciousEmail = 'test@example.com<script>alert("XSS")</script>'
it('should securely render email with HTML characters as text (FIXED VERSION)', () => {
console.log('\n🔒 Security Fix Verification Report')
console.log('===================================')
const { container } = render(<SecureCheckCodeComponent email={maliciousEmail} />)
const spanElement = container.querySelector('span')
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log('\n✅ Fixed Implementation Results:')
console.log('- Email rendered in strong tag:', strongElement?.textContent)
console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('<script>'))
console.log('- Script elements created:', scriptElements.length)
console.log('- Full text content:', spanElement?.textContent)
// Verify secure behavior
expect(strongElement?.textContent).toBe(maliciousEmail) // Email rendered as text
expect(strongElement?.textContent).toContain('<script>') // HTML visible as text
expect(scriptElements).toHaveLength(0) // No script elements created
expect(spanElement?.textContent).toBe(`We send a verification code to ${maliciousEmail}`)
console.log('\n🎯 Security Status: SECURE - HTML automatically escaped by React')
})
it('should demonstrate the vulnerability that was fixed (VULNERABLE VERSION)', () => {
const { container } = render(<VulnerableCheckCodeComponent email={maliciousEmail} />)
const spanElement = container.querySelector('span')
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log('\n⚠ Previous Vulnerable Implementation:')
console.log('- HTML content:', spanElement?.innerHTML)
console.log('- Strong element text:', strongElement?.textContent)
console.log('- Script elements created:', scriptElements.length)
console.log('- Script content:', scriptElements[0]?.textContent)
// Verify vulnerability exists in old implementation
expect(scriptElements).toHaveLength(1) // Script element was created
expect(scriptElements[0]?.textContent).toBe('alert("XSS")') // Contains malicious code
expect(spanElement?.innerHTML).toContain('<script>') // Raw HTML in DOM
console.log('\n❌ Security Status: VULNERABLE - dangerouslySetInnerHTML creates script elements')
})
it('should verify all affected components use the secure pattern', () => {
console.log('\n📋 Component Security Audit')
console.log('============================')
// Test multiple malicious inputs
const testCases = [
'user@test.com<img src=x onerror=alert(1)>',
'test@evil.com<div onclick="alert(2)">click</div>',
'admin@site.com<script>document.cookie="stolen"</script>',
'normal@email.com',
]
testCases.forEach((testEmail, index) => {
const { container } = render(<SecureCheckCodeComponent email={testEmail} />)
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
const imgElements = container.querySelectorAll('img')
const divElements = container.querySelectorAll('div:not([data-testid])')
console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`)
console.log(` - Script elements: ${scriptElements.length}`)
console.log(` - Img elements: ${imgElements.length}`)
console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div
console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`)
// All should be safe
expect(scriptElements).toHaveLength(0)
expect(imgElements).toHaveLength(0)
expect(strongElement?.textContent).toBe(testEmail)
})
console.log('\n✅ All test cases passed - secure rendering confirmed')
})
it('should validate the translation structure is secure', () => {
console.log('\n🔍 Translation Security Analysis')
console.log('=================================')
const { t } = require('react-i18next').useTranslation()
const prefix = t('login.checkCode.tipsPrefix')
console.log('- Translation key used: login.checkCode.tipsPrefix')
console.log('- Translation value:', prefix)
console.log('- Contains HTML tags:', prefix.includes('<'))
console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>'))
// Verify translation is plain text
expect(prefix).toBe('We send a verification code to ')
expect(prefix).not.toContain('<')
expect(prefix).not.toContain('>')
expect(typeof prefix).toBe('string')
console.log('\n✅ Translation structure is secure - no HTML content')
})
it('should confirm React automatic escaping works correctly', () => {
console.log('\n⚡ React Security Mechanism Test')
console.log('=================================')
// Test React's automatic escaping with various inputs
const dangerousInputs = [
'<script>alert("xss")</script>',
'<img src="x" onerror="alert(1)">',
'"><script>alert(2)</script>',
'\'>alert(3)</script>',
'<div onclick="alert(4)">click</div>',
]
dangerousInputs.forEach((input, index) => {
const TestComponent = () => <strong>{input}</strong>
const { container } = render(<TestComponent />)
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`)
console.log(` - Rendered as text: ${strongElement?.textContent === input}`)
console.log(` - No script execution: ${scriptElements.length === 0}`)
expect(strongElement?.textContent).toBe(input)
expect(scriptElements).toHaveLength(0)
})
console.log('\n🛡 React automatic escaping is working perfectly')
})
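it('should escape HTML entities in serialized markup (supplementary sketch)', () => {
// Supplementary sketch; assumes react-dom/server is available in the
// jsdom test environment, as typical @testing-library setups provide.
const { renderToStaticMarkup } = require('react-dom/server')
const html = renderToStaticMarkup(<strong>{'<script>alert(1)</script>'}</strong>)
expect(html).toBe('<strong>&lt;script&gt;alert(1)&lt;/script&gt;</strong>')
})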
})
export {}
