mirror of https://github.com/langgenius/dify.git
merge main
This commit is contained in:
commit
201e4cd64d
|
|
@ -99,3 +99,6 @@ jobs:
|
|||
|
||||
- name: Run Tool
|
||||
run: uv run --project api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ permissions:
|
|||
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
|
|
|||
|
|
@ -235,6 +235,10 @@ Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https:/
|
|||
|
||||
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Deploy to AKS with Azure Devops Pipeline
|
||||
|
||||
One-Click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -217,6 +217,10 @@ docker compose up -d
|
|||
|
||||
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### استخدام Azure Devops Pipeline للنشر على AKS
|
||||
|
||||
انشر Dify على AKS بنقرة واحدة باستخدام [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## المساهمة
|
||||
|
||||
|
|
|
|||
|
|
@ -235,6 +235,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
|
|||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### AKS-এ ডিপ্লয় করার জন্য Azure Devops Pipeline ব্যবহার
|
||||
|
||||
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -233,6 +233,9 @@ docker compose up -d
|
|||
|
||||
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
|
||||
|
||||
#### 使用 Azure Devops Pipeline 部署到AKS
|
||||
|
||||
使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS
|
||||
|
||||
## Star History
|
||||
|
||||
|
|
|
|||
|
|
@ -230,6 +230,10 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Verwendung von Azure Devops Pipeline für AKS-Bereitstellung
|
||||
|
||||
Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -230,6 +230,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Uso de Azure Devops Pipeline para implementar en AKS
|
||||
|
||||
Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Contribuir
|
||||
|
||||
|
|
|
|||
|
|
@ -228,6 +228,10 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Utilisation d'Azure Devops Pipeline pour déployer sur AKS
|
||||
|
||||
Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Contribuer
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,10 @@ docker compose up -d
|
|||
#### Alibaba Cloud Data Management
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
|
||||
|
||||
#### AKSへのデプロイにAzure Devops Pipelineを使用
|
||||
|
||||
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ
|
||||
|
||||
|
||||
## 貢献
|
||||
|
||||
|
|
|
|||
|
|
@ -228,6 +228,10 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
|
|||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### AKS 'e' Deploy je Azure Devops Pipeline lo'laH
|
||||
|
||||
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy
|
||||
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
|
|||
|
|
@ -222,6 +222,10 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
|||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
|
||||
|
||||
#### AKS에 배포하기 위해 Azure Devops Pipeline 사용
|
||||
|
||||
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포
|
||||
|
||||
|
||||
## 기여
|
||||
|
||||
|
|
|
|||
|
|
@ -227,6 +227,10 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Usando Azure Devops Pipeline para Implantar no AKS
|
||||
|
||||
Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Contribuindo
|
||||
|
||||
|
|
|
|||
|
|
@ -228,6 +228,10 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Uporaba Azure Devops Pipeline za uvajanje v AKS
|
||||
|
||||
Z enim klikom namestite Dify v AKS z uporabo [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Prispevam
|
||||
|
||||
|
|
|
|||
|
|
@ -221,6 +221,10 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
|||
|
||||
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
|
||||
|
||||
#### AKS'ye Dağıtım için Azure Devops Pipeline Kullanımı
|
||||
|
||||
[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın
|
||||
|
||||
|
||||
## Katkıda Bulunma
|
||||
|
||||
|
|
|
|||
|
|
@ -233,6 +233,10 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
|||
|
||||
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
|
||||
|
||||
#### 使用 Azure Devops Pipeline 部署到AKS
|
||||
|
||||
使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS
|
||||
|
||||
|
||||
## 貢獻
|
||||
|
||||
|
|
|
|||
|
|
@ -224,6 +224,10 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
|
|||
|
||||
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
|
||||
|
||||
#### Sử dụng Azure Devops Pipeline để Triển khai lên AKS
|
||||
|
||||
Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure Devops Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)
|
||||
|
||||
|
||||
## Đóng góp
|
||||
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
|
|||
TABLESTORE_INSTANCE_NAME=instance-name
|
||||
TABLESTORE_ACCESS_KEY_ID=xxx
|
||||
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
|
||||
|
||||
# Tidb Vector configuration
|
||||
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import secrets
|
|||
from typing import Any, Optional
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
|
@ -462,7 +463,7 @@ def convert_to_agent_apps():
|
|||
"""
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query))
|
||||
rs = conn.execute(sa.text(sql_query))
|
||||
|
||||
apps = []
|
||||
for i in rs:
|
||||
|
|
@ -707,7 +708,7 @@ def fix_app_site_missing():
|
|||
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
|
||||
where sites.id is null limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
processed_count = 0
|
||||
for i in rs:
|
||||
|
|
@ -921,7 +922,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
)
|
||||
orphaned_message_files = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
|
||||
|
||||
|
|
@ -942,7 +943,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
click.echo(
|
||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||
)
|
||||
|
|
@ -959,7 +960,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
|
|
@ -979,7 +980,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
elif ids_table["type"] == "text":
|
||||
|
|
@ -994,7 +995,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
|
|
@ -1013,7 +1014,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
|
|
@ -1042,7 +1043,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
|
||||
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
|
||||
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
|
||||
return
|
||||
|
|
@ -1112,7 +1113,7 @@ def remove_orphaned_files_on_storage(force: bool):
|
|||
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append(str(i[0]))
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ class DatabaseConfig(BaseSettings):
|
|||
|
||||
class CeleryConfig(DatabaseConfig):
|
||||
CELERY_BACKEND: str = Field(
|
||||
description="Backend for Celery task results. Options: 'database', 'redis'.",
|
||||
description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.",
|
||||
default="redis",
|
||||
)
|
||||
|
||||
|
|
@ -245,7 +245,12 @@ class CeleryConfig(DatabaseConfig):
|
|||
|
||||
@computed_field
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
|
||||
if self.CELERY_BACKEND in ("database", "rabbitmq"):
|
||||
return f"db+{self.SQLALCHEMY_DATABASE_URI}"
|
||||
elif self.CELERY_BACKEND == "redis":
|
||||
return self.CELERY_BROKER_URL
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def BROKER_USE_SSL(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings):
|
|||
description="AccessKey secret for the instance name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
|
||||
description="Whether to normalize full-text search scores to [0, 1]",
|
||||
default=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
|
|||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ from .datasets import (
|
|||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
upload_file,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "message_count": i.message_count})
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
|
||||
|
|
@ -310,7 +310,7 @@ ORDER BY
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
|
|
@ -373,7 +373,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
|
|
@ -435,7 +435,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
|
||||
|
||||
|
|
@ -495,7 +495,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
|
@ -71,7 +72,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
|
||||
|
|
@ -133,7 +134,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
|
|
@ -195,7 +196,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
|
|
@ -277,7 +278,7 @@ GROUP BY
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
|
|
|
|||
|
|
@ -645,7 +645,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
return marshal(document_dict, document_status_fields)
|
||||
|
||||
|
||||
class DocumentDetailApi(DocumentResource):
|
||||
class DocumentApi(DocumentResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@setup_required
|
||||
|
|
@ -733,6 +733,28 @@ class DocumentDetailApi(DocumentResource):
|
|||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@setup_required
|
||||
|
|
@ -771,30 +793,6 @@ class DocumentProcessingApi(DocumentResource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -1075,11 +1073,10 @@ api.add_resource(
|
|||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(
|
||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
|
|
@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
|
|||
code = 429
|
||||
|
||||
|
||||
class CompilanceRateLimitError(BaseHTTPException):
|
||||
error_code = "compilance_rate_limit"
|
||||
class ComplianceRateLimitError(BaseHTTPException):
|
||||
error_code = "compliance_rate_limit"
|
||||
description = "Rate limit exceeded for downloading compliance report."
|
||||
code = 429
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
|
|
@ -30,6 +30,7 @@ from libs import helper
|
|||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
|
|
@ -113,7 +114,7 @@ class ChatApi(Resource):
|
|||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
|
||||
parser.add_argument("workflow_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
|
|
@ -128,6 +129,12 @@ class ChatApi(Resource):
|
|||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except IsDraftWorkflowError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except WorkflowIdFormatError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
|
|
@ -15,6 +17,7 @@ from fields.conversation_fields import (
|
|||
simple_conversation_fields,
|
||||
)
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
conversation_variable_infinite_scroll_pagination_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -120,7 +123,41 @@ class ConversationVariablesApi(Resource):
|
|||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
@marshal_with(conversation_variable_fields)
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("value", required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, json.loads(args["value"])
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableNotExistsError:
|
||||
raise NotFound("Conversation Variable Not Exists.")
|
||||
except services.errors.conversation.ConversationVariableTypeMismatchError as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
|
||||
api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name")
|
||||
api.add_resource(ConversationApi, "/conversations")
|
||||
api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail")
|
||||
api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables")
|
||||
api.add_resource(
|
||||
ConversationVariableDetailApi,
|
||||
"/conversations/<uuid:c_id>/variables/<uuid:variable_id>",
|
||||
endpoint="conversation_variable_detail",
|
||||
methods=["PUT"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import request
|
|||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
|
|
@ -34,6 +34,7 @@ from libs.helper import TimestampField
|
|||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
|
@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
|
||||
"""
|
||||
Run specific workflow by ID
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except IsDraftWorkflowError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except WorkflowIdFormatError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
|
|
@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):
|
|||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
|
||||
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
|
||||
|
|
|
|||
|
|
@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
try:
|
||||
# delete document
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
|
|
@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
return data
|
||||
|
||||
|
||||
class DocumentDetailApi(DatasetApiResource):
|
||||
class DocumentApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
|
|
@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):
|
|||
|
||||
return response
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
try:
|
||||
# delete document
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
|
|
@ -588,7 +586,6 @@ api.add_resource(
|
|||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):
|
|||
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
default: Optional[Union[int, str, float, bool]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
|
|||
case FileAttribute.TRANSFER_METHOD:
|
||||
return file.transfer_method.value
|
||||
case FileAttribute.URL:
|
||||
return file.remote_url
|
||||
return _to_url(file)
|
||||
case FileAttribute.EXTENSION:
|
||||
return file.extension
|
||||
case FileAttribute.RELATED_ID:
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class OpsTraceManager:
|
|||
:return:
|
||||
"""
|
||||
# auth check
|
||||
if enabled == True:
|
||||
if enabled:
|
||||
try:
|
||||
provider_config_map[tracing_provider]
|
||||
except KeyError:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from urllib.parse import urlparse
|
|||
import requests
|
||||
from elasticsearch import Elasticsearch
|
||||
from flask import current_app
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
|
|
@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
|
|||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if self._version < "8.0.0":
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
|
||||
def get_type(self) -> str:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import tablestore # type: ignore
|
||||
|
|
@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel):
|
|||
access_key_secret: Optional[str] = None
|
||||
instance_name: Optional[str] = None
|
||||
endpoint: Optional[str] = None
|
||||
normalize_full_text_bm25_score: Optional[bool] = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
|
@ -47,6 +49,7 @@ class TableStoreVector(BaseVector):
|
|||
config.access_key_secret,
|
||||
config.instance_name,
|
||||
)
|
||||
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
|
||||
self._table_name = f"{collection_name}"
|
||||
self._index_name = f"{collection_name}_idx"
|
||||
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
|
||||
|
|
@ -131,8 +134,8 @@ class TableStoreVector(BaseVector):
|
|||
filtered_list = None
|
||||
if document_ids_filter:
|
||||
filtered_list = ["document_id=" + item for item in document_ids_filter]
|
||||
|
||||
return self._search_by_full_text(query, filtered_list, top_k)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
return self._search_by_full_text(query, filtered_list, top_k, score_threshold)
|
||||
|
||||
def delete(self) -> None:
|
||||
self._delete_table_if_exist()
|
||||
|
|
@ -318,7 +321,19 @@ class TableStoreVector(BaseVector):
|
|||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
|
||||
@staticmethod
|
||||
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
|
||||
"""
|
||||
Args:
|
||||
score: BM25 search score.
|
||||
k: decay factor, the larger the k, the steeper the low score end
|
||||
"""
|
||||
normalized_score = 1 - math.exp(-k * score)
|
||||
return max(0.0, min(1.0, normalized_score))
|
||||
|
||||
def _search_by_full_text(
|
||||
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
|
||||
|
||||
|
|
@ -339,15 +354,27 @@ class TableStoreVector(BaseVector):
|
|||
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
score = None
|
||||
if self._normalize_full_text_bm25_score:
|
||||
score = self._normalize_score_exp_decay(search_hit.score)
|
||||
|
||||
# skip when score is below threshold and use normalize score
|
||||
if score and score <= score_threshold:
|
||||
continue
|
||||
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR.value)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
if score:
|
||||
metadata["score"] = score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
|
||||
|
|
@ -355,6 +382,8 @@ class TableStoreVector(BaseVector):
|
|||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
if self._normalize_full_text_bm25_score:
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
|
||||
|
|
@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory):
|
|||
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
||||
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
||||
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
||||
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool):
|
|||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
|
||||
if not target_time:
|
||||
yield self.create_text_message(
|
||||
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from os import listdir, path
|
|||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from yarl import URL
|
||||
|
||||
|
|
@ -616,7 +617,7 @@ class ToolManager:
|
|||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class SegmentType(StrEnum):
|
|||
elif array_validation == ArrayValidation.FIRST:
|
||||
return element_type.is_valid(value[0])
|
||||
else:
|
||||
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
|
||||
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
"""
|
||||
|
|
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
|
|||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have correpond element type.
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
SegmentType.ARRAY_STRING: SegmentType.STRING,
|
||||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
|
|
|
|||
|
|
@ -597,7 +597,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
|
|||
|
||||
for i in range(1, len(raw_results)):
|
||||
spk, txt = raw_results[i]
|
||||
if spk == None:
|
||||
if spk is None:
|
||||
merged_results.append((None, current_text))
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -277,6 +277,22 @@ class Executor:
|
|||
elif self.auth.config.type == "custom":
|
||||
headers[authorization.config.header] = authorization.config.api_key or ""
|
||||
|
||||
# Handle Content-Type for multipart/form-data requests
|
||||
# Fix for issue #22880: Missing boundary when using multipart/form-data
|
||||
body = self.node_data.body
|
||||
if body and body.type == "form-data":
|
||||
# For multipart/form-data with files, let httpx handle the boundary automatically
|
||||
# by not setting Content-Type header when files are present
|
||||
if not self.files or all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
# Only set Content-Type when there are no actual files
|
||||
# This ensures httpx generates the correct boundary
|
||||
if "content-type" not in (k.lower() for k in headers):
|
||||
headers["Content-Type"] = "multipart/form-data"
|
||||
elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE:
|
||||
# Set Content-Type for other body types
|
||||
if "content-type" not in (k.lower() for k in headers):
|
||||
headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
|
||||
|
||||
return headers
|
||||
|
||||
def _validate_and_parse_response(self, response: httpx.Response) -> Response:
|
||||
|
|
@ -384,15 +400,24 @@ class Executor:
|
|||
# '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file.
|
||||
# This prevents logging meaningless placeholder entries.
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for key, (filename, content, mime_type) in self.files:
|
||||
for file_entry in self.files:
|
||||
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
|
||||
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
|
||||
continue # skip malformed entries
|
||||
key = file_entry[0]
|
||||
content = file_entry[1][1]
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
# decode content
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# fix: decode binary content
|
||||
pass
|
||||
# decode content safely
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
elif isinstance(content, str):
|
||||
body_string += content
|
||||
else:
|
||||
body_string += f"[Unsupported content type: {type(content).__name__}]"
|
||||
body_string += "\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
|
|
@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import (
|
|||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
|
|
@ -1006,21 +1004,6 @@ class LLMNode(BaseNode):
|
|||
)
|
||||
return saved_file
|
||||
|
||||
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
|
||||
"""
|
||||
Fetch model schema
|
||||
"""
|
||||
model_name = self._node_data.model.name
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
|
||||
)
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
model_credentials = model_instance.credentials
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_schema
|
||||
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -318,6 +318,33 @@ class ToolNode(BaseNode):
|
|||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": message.message.text,
|
||||
}
|
||||
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import mimetypes
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
|
@ -241,16 +243,21 @@ def _build_from_remote_url(
|
|||
|
||||
def _get_remote_file_info(url: str):
|
||||
file_size = -1
|
||||
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
|
||||
mime_type = mimetypes.guess_type(filename)[0] or ""
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
|
||||
# Initialize mime_type from filename as fallback
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
|
||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
resp = cast(httpx.Response, resp)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
if content_disposition := resp.headers.get("Content-Disposition"):
|
||||
filename = str(content_disposition.split("filename=")[-1].strip('"'))
|
||||
# Re-guess mime_type from updated filename
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
file_size = int(resp.headers.get("Content-Length", file_size))
|
||||
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
|
||||
|
||||
return mime_type, filename, file_size
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ model_config_fields = {
|
|||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_detail_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
|
|
@ -77,6 +79,7 @@ app_detail_fields = {
|
|||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
|
|
@ -92,8 +95,6 @@ model_config_partial_fields = {
|
|||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_partial_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
|
|
@ -185,7 +186,6 @@ app_detail_fields_with_site = {
|
|||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"site": fields.Nested(site_fields),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
|
|
@ -195,6 +195,8 @@ app_detail_fields_with_site = {
|
|||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"site": fields.Nested(site_fields),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
"""manual dataset field update
|
||||
|
||||
Revision ID: 532b3f888abf
|
||||
Revises: 8bcc02c9bd07
|
||||
Create Date: 2025-07-24 14:50:48.779833
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '532b3f888abf'
|
||||
down_revision = '8bcc02c9bd07'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying")
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'")
|
||||
|
|
@ -3,8 +3,9 @@ import json
|
|||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
|
||||
|
||||
from models.base import Base
|
||||
|
|
@ -83,26 +84,24 @@ class AccountStatus(enum.StrEnum):
|
|||
|
||||
class Account(UserMixin, Base):
|
||||
__tablename__ = "accounts"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(db.String(255))
|
||||
email: Mapped[str] = mapped_column(db.String(255))
|
||||
password: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
password_salt: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
|
||||
interface_language: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
|
||||
timezone: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, server_default=func.current_timestamp(), nullable=False
|
||||
)
|
||||
status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying"))
|
||||
initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
email: Mapped[str] = mapped_column(String(255))
|
||||
password: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
password_salt: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
interface_language: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
timezone: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
|
||||
initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
|
||||
@reconstructor
|
||||
def init_on_load(self):
|
||||
|
|
@ -197,16 +196,16 @@ class TenantStatus(enum.StrEnum):
|
|||
|
||||
class Tenant(Base):
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(db.String(255))
|
||||
encrypt_public_key = db.Column(db.Text)
|
||||
plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying"))
|
||||
status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
|
||||
custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
encrypt_public_key = db.Column(sa.Text)
|
||||
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
||||
custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
||||
def get_accounts(self) -> list[Account]:
|
||||
return (
|
||||
|
|
@ -227,56 +226,56 @@ class Tenant(Base):
|
|||
class TenantAccountJoin(Base):
|
||||
__tablename__ = "tenant_account_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
db.Index("tenant_account_join_account_id_idx", "account_id"),
|
||||
db.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
|
||||
db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
sa.Index("tenant_account_join_account_id_idx", "account_id"),
|
||||
sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
role: Mapped[str] = mapped_column(db.String(16), server_default="normal")
|
||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
role: Mapped[str] = mapped_column(String(16), server_default="normal")
|
||||
invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AccountIntegrate(Base):
|
||||
__tablename__ = "account_integrates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
db.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
|
||||
db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
|
||||
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
provider: Mapped[str] = mapped_column(db.String(16))
|
||||
open_id: Mapped[str] = mapped_column(db.String(255))
|
||||
encrypted_token: Mapped[str] = mapped_column(db.String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
provider: Mapped[str] = mapped_column(String(16))
|
||||
open_id: Mapped[str] = mapped_column(String(255))
|
||||
encrypted_token: Mapped[str] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class InvitationCode(Base):
|
||||
__tablename__ = "invitation_codes"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
db.Index("invitation_codes_batch_idx", "batch"),
|
||||
db.Index("invitation_codes_code_idx", "code", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
sa.Index("invitation_codes_batch_idx", "batch"),
|
||||
sa.Index("invitation_codes_code_idx", "code", "status"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(db.Integer)
|
||||
batch: Mapped[str] = mapped_column(db.String(255))
|
||||
code: Mapped[str] = mapped_column(db.String(32))
|
||||
status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying"))
|
||||
used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
id: Mapped[int] = mapped_column(sa.Integer)
|
||||
batch: Mapped[str] = mapped_column(String(255))
|
||||
code: Mapped[str] = mapped_column(String(32))
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
|
||||
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class TenantPluginPermission(Base):
|
||||
|
|
@ -292,16 +291,14 @@ class TenantPluginPermission(Base):
|
|||
|
||||
__tablename__ = "account_plugin_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
|
||||
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||
sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
install_permission: Mapped[InstallPermission] = mapped_column(
|
||||
db.String(16), nullable=False, server_default="everyone"
|
||||
)
|
||||
debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")
|
||||
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
|
||||
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
|
||||
|
||||
|
||||
class TenantPluginAutoUpgradeStrategy(Base):
|
||||
|
|
@ -317,20 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base):
|
|||
|
||||
__tablename__ = "tenant_plugin_auto_upgrade_strategies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
|
||||
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only")
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day
|
||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude")
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(
|
||||
db.ARRAY(db.String(255)), nullable=False
|
||||
) # plugin_id (author/name)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(
|
||||
db.ARRAY(db.String(255)), nullable=False
|
||||
) # plugin_id (author/name)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
|
||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import enum
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import mapped_column
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
|
|
@ -18,13 +19,13 @@ class APIBasedExtensionPoint(enum.Enum):
|
|||
class APIBasedExtension(Base):
|
||||
__tablename__ = "api_based_extensions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
db.Index("api_based_extension_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
api_endpoint = mapped_column(db.String(255), nullable=False)
|
||||
api_key = mapped_column(db.Text, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
api_key = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ from datetime import datetime
|
|||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import func, select
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
|
@ -38,25 +39,25 @@ class DatasetPermissionEnum(enum.StrEnum):
|
|||
class Dataset(Base):
|
||||
__tablename__ = "datasets"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
db.Index("dataset_tenant_idx", "tenant_id"),
|
||||
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
sa.Index("dataset_tenant_idx", "tenant_id"),
|
||||
sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
PROVIDER_LIST = ["vendor", "external", None]
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
name: Mapped[str] = mapped_column(db.String(255))
|
||||
description = mapped_column(db.Text, nullable=True)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying"))
|
||||
permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying"))
|
||||
data_source_type = mapped_column(db.String(255))
|
||||
indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
index_struct = mapped_column(db.Text, nullable=True)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
description = mapped_column(sa.Text, nullable=True)
|
||||
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
|
||||
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
|
||||
data_source_type = mapped_column(String(255))
|
||||
indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
index_struct = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column
|
||||
|
|
@ -294,16 +295,16 @@ class Dataset(Base):
|
|||
class DatasetProcessRule(Base):
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
|
||||
rules = mapped_column(db.Text, nullable=True)
|
||||
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
|
||||
rules = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
MODES = ["automatic", "custom", "hierarchical"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
|
|
@ -334,72 +335,70 @@ class DatasetProcessRule(Base):
|
|||
class Document(Base):
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
db.Index("document_dataset_id_idx", "dataset_id"),
|
||||
db.Index("document_is_paused_idx", "is_paused"),
|
||||
db.Index("document_tenant_idx", "tenant_id"),
|
||||
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
sa.Index("document_dataset_id_idx", "dataset_id"),
|
||||
sa.Index("document_is_paused_idx", "is_paused"),
|
||||
sa.Index("document_tenant_idx", "tenant_id"),
|
||||
sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
position = mapped_column(db.Integer, nullable=False)
|
||||
data_source_type = mapped_column(db.String(255), nullable=False)
|
||||
data_source_info = mapped_column(db.Text, nullable=True)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_source_info = mapped_column(sa.Text, nullable=True)
|
||||
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
|
||||
batch = mapped_column(db.String(255), nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
created_from = mapped_column(db.String(255), nullable=False)
|
||||
batch: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_api_request_id = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
# start processing
|
||||
processing_started_at = mapped_column(db.DateTime, nullable=True)
|
||||
processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# parsing
|
||||
file_id = mapped_column(db.Text, nullable=True)
|
||||
word_count = mapped_column(db.Integer, nullable=True)
|
||||
parsing_completed_at = mapped_column(db.DateTime, nullable=True)
|
||||
file_id = mapped_column(sa.Text, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
|
||||
parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# cleaning
|
||||
cleaning_completed_at = mapped_column(db.DateTime, nullable=True)
|
||||
cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# split
|
||||
splitting_completed_at = mapped_column(db.DateTime, nullable=True)
|
||||
splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# indexing
|
||||
tokens = mapped_column(db.Integer, nullable=True)
|
||||
indexing_latency = mapped_column(db.Float, nullable=True)
|
||||
completed_at = mapped_column(db.DateTime, nullable=True)
|
||||
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# pause
|
||||
is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
paused_by = mapped_column(StringUUID, nullable=True)
|
||||
paused_at = mapped_column(db.DateTime, nullable=True)
|
||||
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# error
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
stopped_at = mapped_column(db.DateTime, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# basic fields
|
||||
indexing_status = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")
|
||||
)
|
||||
enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
disabled_at = mapped_column(db.DateTime, nullable=True)
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
archived_reason = mapped_column(db.String(255), nullable=True)
|
||||
archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
archived_reason = mapped_column(String(255), nullable=True)
|
||||
archived_by = mapped_column(StringUUID, nullable=True)
|
||||
archived_at = mapped_column(db.DateTime, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
doc_type = mapped_column(db.String(40), nullable=True)
|
||||
archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_metadata = mapped_column(JSONB, nullable=True)
|
||||
doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
doc_language = mapped_column(db.String(255), nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
|
||||
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
|
||||
|
||||
|
|
@ -556,7 +555,7 @@ class Document(Base):
|
|||
"id": "built-in",
|
||||
"name": BuiltInField.upload_date,
|
||||
"type": "time",
|
||||
"value": self.created_at.timestamp(),
|
||||
"value": str(self.created_at.timestamp()),
|
||||
}
|
||||
)
|
||||
built_in_fields.append(
|
||||
|
|
@ -564,7 +563,7 @@ class Document(Base):
|
|||
"id": "built-in",
|
||||
"name": BuiltInField.last_update_date,
|
||||
"type": "time",
|
||||
"value": self.updated_at.timestamp(),
|
||||
"value": str(self.updated_at.timestamp()),
|
||||
}
|
||||
)
|
||||
built_in_fields.append(
|
||||
|
|
@ -677,45 +676,45 @@ class Document(Base):
|
|||
class DocumentSegment(Base):
|
||||
__tablename__ = "document_segments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
db.Index("document_segment_dataset_id_idx", "dataset_id"),
|
||||
db.Index("document_segment_document_id_idx", "document_id"),
|
||||
db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
|
||||
db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
|
||||
db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
|
||||
db.Index("document_segment_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
sa.Index("document_segment_dataset_id_idx", "dataset_id"),
|
||||
sa.Index("document_segment_document_id_idx", "document_id"),
|
||||
sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
|
||||
sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
|
||||
sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
|
||||
sa.Index("document_segment_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int]
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
answer = mapped_column(db.Text, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
answer = mapped_column(sa.Text, nullable=True)
|
||||
word_count: Mapped[int]
|
||||
tokens: Mapped[int]
|
||||
|
||||
# indexing fields
|
||||
keywords = mapped_column(db.JSON, nullable=True)
|
||||
index_node_id = mapped_column(db.String(255), nullable=True)
|
||||
index_node_hash = mapped_column(db.String(255), nullable=True)
|
||||
keywords = mapped_column(sa.JSON, nullable=True)
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
|
||||
# basic fields
|
||||
hit_count = mapped_column(db.Integer, nullable=False, default=0)
|
||||
enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
disabled_at = mapped_column(db.DateTime, nullable=True)
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
indexing_at = mapped_column(db.DateTime, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
stopped_at = mapped_column(db.DateTime, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
|
|
@ -828,32 +827,36 @@ class DocumentSegment(Base):
|
|||
class ChildChunk(Base):
|
||||
__tablename__ = "child_chunks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
|
||||
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
|
||||
db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
|
||||
db.Index("child_chunks_segment_idx", "segment_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
|
||||
sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
|
||||
sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
|
||||
sa.Index("child_chunks_segment_idx", "segment_id"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
segment_id = mapped_column(StringUUID, nullable=False)
|
||||
position = mapped_column(db.Integer, nullable=False)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
word_count = mapped_column(db.Integer, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
# indexing fields
|
||||
index_node_id = mapped_column(db.String(255), nullable=True)
|
||||
index_node_hash = mapped_column(db.String(255), nullable=True)
|
||||
type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
indexing_at = mapped_column(db.DateTime, nullable=True)
|
||||
completed_at = mapped_column(db.DateTime, nullable=True)
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
|
|
@ -871,14 +874,14 @@ class ChildChunk(Base):
|
|||
class AppDatasetJoin(Base):
|
||||
__tablename__ = "app_dataset_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -888,32 +891,32 @@ class AppDatasetJoin(Base):
|
|||
class DatasetQuery(Base):
|
||||
__tablename__ = "dataset_queries"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
db.Index("dataset_query_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
sa.Index("dataset_query_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
source = mapped_column(db.String(255), nullable=False)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_app_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role = mapped_column(db.String, nullable=False)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetKeywordTable(Base):
|
||||
__tablename__ = "dataset_keyword_tables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
|
||||
keyword_table = mapped_column(db.Text, nullable=False)
|
||||
keyword_table = mapped_column(sa.Text, nullable=False)
|
||||
data_source_type = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'database'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'database'::character varying")
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -950,19 +953,19 @@ class DatasetKeywordTable(Base):
|
|||
class Embedding(Base):
|
||||
__tablename__ = "embeddings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
|
||||
db.Index("created_at_idx", "created_at"),
|
||||
sa.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
|
||||
sa.Index("created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
model_name = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying")
|
||||
)
|
||||
hash = mapped_column(db.String(64), nullable=False)
|
||||
embedding = mapped_column(db.LargeBinary, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
|
||||
hash = mapped_column(String(64), nullable=False)
|
||||
embedding = mapped_column(sa.LargeBinary, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying"))
|
||||
|
||||
def set_embedding(self, embedding_data: list[float]):
|
||||
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
|
@ -974,84 +977,84 @@ class Embedding(Base):
|
|||
class DatasetCollectionBinding(Base):
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
db.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
provider_name = mapped_column(db.String(255), nullable=False)
|
||||
model_name = mapped_column(db.String(255), nullable=False)
|
||||
type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
|
||||
collection_name = mapped_column(db.String(64), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False)
|
||||
collection_name = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TidbAuthBinding(Base):
|
||||
__tablename__ = "tidb_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
|
||||
db.Index("tidb_auth_bindings_active_idx", "active"),
|
||||
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
|
||||
db.Index("tidb_auth_bindings_status_idx", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
|
||||
sa.Index("tidb_auth_bindings_active_idx", "active"),
|
||||
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
|
||||
sa.Index("tidb_auth_bindings_status_idx", "status"),
|
||||
)
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
cluster_id = mapped_column(db.String(255), nullable=False)
|
||||
cluster_name = mapped_column(db.String(255), nullable=False)
|
||||
active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING"))
|
||||
account = mapped_column(db.String(255), nullable=False)
|
||||
password = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Whitelist(Base):
|
||||
__tablename__ = "whitelists"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
db.Index("whitelists_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
sa.Index("whitelists_tenant_idx", "tenant_id"),
|
||||
)
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
category = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetPermission(Base):
|
||||
__tablename__ = "dataset_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
|
||||
db.Index("idx_dataset_permissions_account_id", "account_id"),
|
||||
db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
|
||||
sa.Index("idx_dataset_permissions_account_id", "account_id"),
|
||||
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ExternalKnowledgeApis(Base):
|
||||
__tablename__ = "external_knowledge_apis"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
|
||||
db.Index("external_knowledge_apis_name_idx", "name"),
|
||||
sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
|
||||
sa.Index("external_knowledge_apis_name_idx", "name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
description = mapped_column(db.String(255), nullable=False)
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
settings = mapped_column(db.Text, nullable=True)
|
||||
settings = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
@ -1091,71 +1094,79 @@ class ExternalKnowledgeApis(Base):
|
|||
class ExternalKnowledgeBindings(Base):
|
||||
__tablename__ = "external_knowledge_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||
sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||
sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
external_knowledge_api_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
external_knowledge_id = mapped_column(db.Text, nullable=False)
|
||||
external_knowledge_id = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetAutoDisableLog(Base):
|
||||
__tablename__ = "dataset_auto_disable_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
|
||||
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
|
||||
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class RateLimitLog(Base):
|
||||
__tablename__ = "rate_limit_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
|
||||
db.Index("rate_limit_log_tenant_idx", "tenant_id"),
|
||||
db.Index("rate_limit_log_operation_idx", "operation"),
|
||||
sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
|
||||
sa.Index("rate_limit_log_tenant_idx", "tenant_id"),
|
||||
sa.Index("rate_limit_log_operation_idx", "operation"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
subscription_plan = mapped_column(db.String(255), nullable=False)
|
||||
operation = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
operation: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class DatasetMetadata(Base):
|
||||
__tablename__ = "dataset_metadatas"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
sa.Index("dataset_metadata_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
type = mapped_column(db.String(255), nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
||||
|
|
@ -1163,19 +1174,19 @@ class DatasetMetadata(Base):
|
|||
class DatasetMetadataBinding(Base):
|
||||
__tablename__ = "dataset_metadata_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
|
||||
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
|
||||
db.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
|
||||
sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
|
||||
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
metadata_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
|||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -35,10 +35,10 @@ from .types import StringUUID
|
|||
|
||||
class DifySetup(Base):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
||||
version = mapped_column(db.String(255), nullable=False)
|
||||
setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMode(StrEnum):
|
||||
|
|
@ -71,33 +71,33 @@ class IconType(Enum):
|
|||
|
||||
class App(Base):
|
||||
__tablename__ = "apps"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
name: Mapped[str] = mapped_column(db.String(255))
|
||||
description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying"))
|
||||
mode: Mapped[str] = mapped_column(db.String(255))
|
||||
icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji
|
||||
icon = db.Column(db.String(255))
|
||||
icon_background: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji
|
||||
icon = db.Column(String(255))
|
||||
icon_background: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
app_model_config_id = mapped_column(StringUUID, nullable=True)
|
||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
|
||||
enable_site: Mapped[bool] = mapped_column(db.Boolean)
|
||||
enable_api: Mapped[bool] = mapped_column(db.Boolean)
|
||||
api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
tracing = mapped_column(db.Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
||||
enable_site: Mapped[bool] = mapped_column(sa.Boolean)
|
||||
enable_api: Mapped[bool] = mapped_column(sa.Boolean)
|
||||
api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
|
||||
api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
|
||||
is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
tracing = mapped_column(sa.Text, nullable=True)
|
||||
max_active_requests: Mapped[Optional[int]]
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def desc_or_prompt(self):
|
||||
|
|
@ -304,36 +304,36 @@ class App(Base):
|
|||
|
||||
class AppModelConfig(Base):
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
provider = mapped_column(db.String(255), nullable=True)
|
||||
model_id = mapped_column(db.String(255), nullable=True)
|
||||
configs = mapped_column(db.JSON, nullable=True)
|
||||
provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
configs = mapped_column(sa.JSON, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
opening_statement = mapped_column(db.Text)
|
||||
suggested_questions = mapped_column(db.Text)
|
||||
suggested_questions_after_answer = mapped_column(db.Text)
|
||||
speech_to_text = mapped_column(db.Text)
|
||||
text_to_speech = mapped_column(db.Text)
|
||||
more_like_this = mapped_column(db.Text)
|
||||
model = mapped_column(db.Text)
|
||||
user_input_form = mapped_column(db.Text)
|
||||
dataset_query_variable = mapped_column(db.String(255))
|
||||
pre_prompt = mapped_column(db.Text)
|
||||
agent_mode = mapped_column(db.Text)
|
||||
sensitive_word_avoidance = mapped_column(db.Text)
|
||||
retriever_resource = mapped_column(db.Text)
|
||||
prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying"))
|
||||
chat_prompt_config = mapped_column(db.Text)
|
||||
completion_prompt_config = mapped_column(db.Text)
|
||||
dataset_configs = mapped_column(db.Text)
|
||||
external_data_tools = mapped_column(db.Text)
|
||||
file_upload = mapped_column(db.Text)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
opening_statement = mapped_column(sa.Text)
|
||||
suggested_questions = mapped_column(sa.Text)
|
||||
suggested_questions_after_answer = mapped_column(sa.Text)
|
||||
speech_to_text = mapped_column(sa.Text)
|
||||
text_to_speech = mapped_column(sa.Text)
|
||||
more_like_this = mapped_column(sa.Text)
|
||||
model = mapped_column(sa.Text)
|
||||
user_input_form = mapped_column(sa.Text)
|
||||
dataset_query_variable = mapped_column(String(255))
|
||||
pre_prompt = mapped_column(sa.Text)
|
||||
agent_mode = mapped_column(sa.Text)
|
||||
sensitive_word_avoidance = mapped_column(sa.Text)
|
||||
retriever_resource = mapped_column(sa.Text)
|
||||
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying"))
|
||||
chat_prompt_config = mapped_column(sa.Text)
|
||||
completion_prompt_config = mapped_column(sa.Text)
|
||||
dataset_configs = mapped_column(sa.Text)
|
||||
external_data_tools = mapped_column(sa.Text)
|
||||
file_upload = mapped_column(sa.Text)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -555,24 +555,24 @@ class AppModelConfig(Base):
|
|||
class RecommendedApp(Base):
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
db.Index("recommended_app_app_id_idx", "app_id"),
|
||||
db.Index("recommended_app_is_listed_idx", "is_listed", "language"),
|
||||
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
sa.Index("recommended_app_app_id_idx", "app_id"),
|
||||
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
description = mapped_column(db.JSON, nullable=False)
|
||||
copyright = mapped_column(db.String(255), nullable=False)
|
||||
privacy_policy = mapped_column(db.String(255), nullable=False)
|
||||
description = mapped_column(sa.JSON, nullable=False)
|
||||
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
category = mapped_column(db.String(255), nullable=False)
|
||||
position = mapped_column(db.Integer, nullable=False, default=0)
|
||||
is_listed = mapped_column(db.Boolean, nullable=False, default=True)
|
||||
install_count = mapped_column(db.Integer, nullable=False, default=0)
|
||||
language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -583,20 +583,20 @@ class RecommendedApp(Base):
|
|||
class InstalledApp(Base):
|
||||
__tablename__ = "installed_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||
db.Index("installed_app_tenant_id_idx", "tenant_id"),
|
||||
db.Index("installed_app_app_id_idx", "app_id"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
|
||||
sa.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||
sa.Index("installed_app_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("installed_app_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
app_owner_tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
position = mapped_column(db.Integer, nullable=False, default=0)
|
||||
is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
last_used_at = mapped_column(db.DateTime, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -612,47 +612,47 @@ class InstalledApp(Base):
|
|||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
app_model_config_id = mapped_column(StringUUID, nullable=True)
|
||||
model_provider = mapped_column(db.String(255), nullable=True)
|
||||
override_model_configs = mapped_column(db.Text)
|
||||
model_id = mapped_column(db.String(255), nullable=True)
|
||||
mode: Mapped[str] = mapped_column(db.String(255))
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
summary = mapped_column(db.Text)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
introduction = mapped_column(db.Text)
|
||||
system_instruction = mapped_column(db.Text)
|
||||
system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
status = mapped_column(db.String(255), nullable=False)
|
||||
model_provider = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(sa.Text)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
summary = mapped_column(sa.Text)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
introduction = mapped_column(sa.Text)
|
||||
system_instruction = mapped_column(sa.Text)
|
||||
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
# The `invoke_from` records how the conversation is created.
|
||||
#
|
||||
# Its value corresponds to the members of `InvokeFrom`.
|
||||
# (api/core/app/entities/app_invoke_entities.py)
|
||||
invoke_from = mapped_column(db.String(255), nullable=True)
|
||||
invoke_from = mapped_column(String(255), nullable=True)
|
||||
|
||||
# ref: ConversationSource.
|
||||
from_source = mapped_column(db.String(255), nullable=False)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
read_at = mapped_column(db.DateTime)
|
||||
read_at = mapped_column(sa.DateTime)
|
||||
read_account_id = mapped_column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||
message_annotations = db.relationship(
|
||||
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
|
||||
)
|
||||
|
||||
is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
|
|
@ -894,36 +894,36 @@ class Message(Base):
|
|||
Index("message_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
model_provider = mapped_column(db.String(255), nullable=True)
|
||||
model_id = mapped_column(db.String(255), nullable=True)
|
||||
override_model_configs = mapped_column(db.Text)
|
||||
conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
message = mapped_column(db.JSON, nullable=False)
|
||||
message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column
|
||||
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
model_provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(sa.Text)
|
||||
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
message = mapped_column(sa.JSON, nullable=False)
|
||||
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
parent_message_id = mapped_column(StringUUID, nullable=True)
|
||||
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = mapped_column(db.Numeric(10, 7))
|
||||
currency = mapped_column(db.String(255), nullable=False)
|
||||
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
error = mapped_column(db.Text)
|
||||
message_metadata = mapped_column(db.Text)
|
||||
invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
|
||||
from_source = mapped_column(db.String(255), nullable=False)
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
error = mapped_column(sa.Text)
|
||||
message_metadata = mapped_column(sa.Text)
|
||||
invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
|
||||
@property
|
||||
|
|
@ -1230,23 +1230,23 @@ class Message(Base):
|
|||
class MessageFeedback(Base):
|
||||
__tablename__ = "message_feedbacks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
db.Index("message_feedback_app_idx", "app_id"),
|
||||
db.Index("message_feedback_message_idx", "message_id", "from_source"),
|
||||
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
sa.Index("message_feedback_app_idx", "app_id"),
|
||||
sa.Index("message_feedback_message_idx", "message_id", "from_source"),
|
||||
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
rating = mapped_column(db.String(255), nullable=False)
|
||||
content = mapped_column(db.Text)
|
||||
from_source = mapped_column(db.String(255), nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content = mapped_column(sa.Text)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def from_account(self):
|
||||
|
|
@ -1272,9 +1272,9 @@ class MessageFeedback(Base):
|
|||
class MessageFile(Base):
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
db.Index("message_file_message_idx", "message_id"),
|
||||
db.Index("message_file_created_by_idx", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
sa.Index("message_file_message_idx", "message_id"),
|
||||
sa.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
|
@ -1298,37 +1298,37 @@ class MessageFile(Base):
|
|||
self.created_by_role = created_by_role.value
|
||||
self.created_by = created_by
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAnnotation(Base):
|
||||
__tablename__ = "message_annotations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
db.Index("message_annotation_app_idx", "app_id"),
|
||||
db.Index("message_annotation_conversation_idx", "conversation_id"),
|
||||
db.Index("message_annotation_message_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
sa.Index("message_annotation_app_idx", "app_id"),
|
||||
sa.Index("message_annotation_conversation_idx", "conversation_id"),
|
||||
sa.Index("message_annotation_message_idx", "message_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id"))
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
|
||||
message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
question = db.Column(db.Text, nullable=True)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
question = db.Column(sa.Text, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1344,24 +1344,24 @@ class MessageAnnotation(Base):
|
|||
class AppAnnotationHitHistory(Base):
|
||||
__tablename__ = "app_annotation_hit_histories"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
db.Index("app_annotation_hit_histories_app_idx", "app_id"),
|
||||
db.Index("app_annotation_hit_histories_account_idx", "account_id"),
|
||||
db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
|
||||
db.Index("app_annotation_hit_histories_message_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
sa.Index("app_annotation_hit_histories_app_idx", "app_id"),
|
||||
sa.Index("app_annotation_hit_histories_account_idx", "account_id"),
|
||||
sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
|
||||
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
source = mapped_column(db.Text, nullable=False)
|
||||
question = mapped_column(db.Text, nullable=False)
|
||||
source = mapped_column(sa.Text, nullable=False)
|
||||
question = mapped_column(sa.Text, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
score = mapped_column(Float, nullable=False, server_default=db.text("0"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
annotation_question = mapped_column(db.Text, nullable=False)
|
||||
annotation_content = mapped_column(db.Text, nullable=False)
|
||||
annotation_question = mapped_column(sa.Text, nullable=False)
|
||||
annotation_content = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1382,18 +1382,18 @@ class AppAnnotationHitHistory(Base):
|
|||
class AppAnnotationSetting(Base):
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
db.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
sa.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0"))
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
|
|
@ -1410,58 +1410,58 @@ class AppAnnotationSetting(Base):
|
|||
class OperationLog(Base):
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
|
||||
sa.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
action = mapped_column(db.String(255), nullable=False)
|
||||
content = mapped_column(db.JSON)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip = mapped_column(db.String(255), nullable=False)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
action: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content = mapped_column(sa.JSON)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class EndUser(Base, UserMixin):
|
||||
__tablename__ = "end_users"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
db.Index("end_user_session_id_idx", "session_id", "type"),
|
||||
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
||||
sa.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
sa.Index("end_user_session_id_idx", "session_id", "type"),
|
||||
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(db.String(255), nullable=False)
|
||||
external_user_id = mapped_column(db.String(255), nullable=True)
|
||||
name = mapped_column(db.String(255))
|
||||
is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
external_user_id = mapped_column(String(255), nullable=True)
|
||||
name = mapped_column(String(255))
|
||||
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
session_id: Mapped[str] = mapped_column()
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMCPServer(Base):
|
||||
__tablename__ = "app_mcp_servers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
|
||||
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
|
||||
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
description = mapped_column(db.String(255), nullable=False)
|
||||
server_code = mapped_column(db.String(255), nullable=False)
|
||||
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
parameters = mapped_column(db.Text, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
server_code: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
parameters = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n):
|
||||
|
|
@ -1480,35 +1480,35 @@ class AppMCPServer(Base):
|
|||
class Site(Base):
|
||||
__tablename__ = "sites"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||
db.Index("site_app_id_idx", "app_id"),
|
||||
db.Index("site_code_idx", "code", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||
sa.Index("site_app_id_idx", "app_id"),
|
||||
sa.Index("site_code_idx", "code", "status"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
title = mapped_column(db.String(255), nullable=False)
|
||||
icon_type = mapped_column(db.String(255), nullable=True)
|
||||
icon = mapped_column(db.String(255))
|
||||
icon_background = mapped_column(db.String(255))
|
||||
description = mapped_column(db.Text)
|
||||
default_language = mapped_column(db.String(255), nullable=False)
|
||||
chat_color_theme = mapped_column(db.String(255))
|
||||
chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
copyright = mapped_column(db.String(255))
|
||||
privacy_policy = mapped_column(db.String(255))
|
||||
show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
icon_type = mapped_column(String(255), nullable=True)
|
||||
icon = mapped_column(String(255))
|
||||
icon_background = mapped_column(String(255))
|
||||
description = mapped_column(sa.Text)
|
||||
default_language: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
chat_color_theme = mapped_column(String(255))
|
||||
chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
copyright = mapped_column(String(255))
|
||||
privacy_policy = mapped_column(String(255))
|
||||
show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
|
||||
customize_domain = mapped_column(db.String(255))
|
||||
customize_token_strategy = mapped_column(db.String(255), nullable=False)
|
||||
prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
customize_domain = mapped_column(String(255))
|
||||
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
code = mapped_column(db.String(255))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
code = mapped_column(String(255))
|
||||
|
||||
@property
|
||||
def custom_disclaimer(self):
|
||||
|
|
@ -1537,19 +1537,19 @@ class Site(Base):
|
|||
class ApiToken(Base):
|
||||
__tablename__ = "api_tokens"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
db.Index("api_token_app_id_type_idx", "app_id", "type"),
|
||||
db.Index("api_token_token_idx", "token", "type"),
|
||||
db.Index("api_token_tenant_idx", "tenant_id", "type"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
sa.Index("api_token_app_id_type_idx", "app_id", "type"),
|
||||
sa.Index("api_token_token_idx", "token", "type"),
|
||||
sa.Index("api_token_tenant_idx", "tenant_id", "type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(db.String(16), nullable=False)
|
||||
token = mapped_column(db.String(255), nullable=False)
|
||||
last_used_at = mapped_column(db.DateTime, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key(prefix, n):
|
||||
|
|
@ -1563,27 +1563,27 @@ class ApiToken(Base):
|
|||
class UploadFile(Base):
|
||||
__tablename__ = "upload_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
db.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
sa.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
extension: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
extension: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'account'::character varying")
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
def __init__(
|
||||
|
|
@ -1625,71 +1625,71 @@ class UploadFile(Base):
|
|||
class ApiRequest(Base):
|
||||
__tablename__ = "api_requests"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||
db.Index("api_request_token_idx", "tenant_id", "api_token_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
api_token_id = mapped_column(StringUUID, nullable=False)
|
||||
path = mapped_column(db.String(255), nullable=False)
|
||||
request = mapped_column(db.Text, nullable=True)
|
||||
response = mapped_column(db.Text, nullable=True)
|
||||
ip = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
path: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
request = mapped_column(sa.Text, nullable=True)
|
||||
response = mapped_column(sa.Text, nullable=True)
|
||||
ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageChain(Base):
|
||||
__tablename__ = "message_chains"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||
db.Index("message_chain_message_id_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||
sa.Index("message_chain_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
type = mapped_column(db.String(255), nullable=False)
|
||||
input = mapped_column(db.Text, nullable=True)
|
||||
output = mapped_column(db.Text, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
input = mapped_column(sa.Text, nullable=True)
|
||||
output = mapped_column(sa.Text, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(Base):
|
||||
__tablename__ = "message_agent_thoughts"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
db.Index("message_agent_thought_message_id_idx", "message_id"),
|
||||
db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
sa.Index("message_agent_thought_message_id_idx", "message_id"),
|
||||
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
message_chain_id = mapped_column(StringUUID, nullable=True)
|
||||
position = mapped_column(db.Integer, nullable=False)
|
||||
thought = mapped_column(db.Text, nullable=True)
|
||||
tool = mapped_column(db.Text, nullable=True)
|
||||
tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_input = mapped_column(db.Text, nullable=True)
|
||||
observation = mapped_column(db.Text, nullable=True)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
thought = mapped_column(sa.Text, nullable=True)
|
||||
tool = mapped_column(sa.Text, nullable=True)
|
||||
tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
|
||||
tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
|
||||
tool_input = mapped_column(sa.Text, nullable=True)
|
||||
observation = mapped_column(sa.Text, nullable=True)
|
||||
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
|
||||
tool_process_data = mapped_column(db.Text, nullable=True)
|
||||
message = mapped_column(db.Text, nullable=True)
|
||||
message_token = mapped_column(db.Integer, nullable=True)
|
||||
message_unit_price = mapped_column(db.Numeric, nullable=True)
|
||||
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
message_files = mapped_column(db.Text, nullable=True)
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
answer_token = mapped_column(db.Integer, nullable=True)
|
||||
answer_unit_price = mapped_column(db.Numeric, nullable=True)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
tokens = mapped_column(db.Integer, nullable=True)
|
||||
total_price = mapped_column(db.Numeric, nullable=True)
|
||||
currency = mapped_column(db.String, nullable=True)
|
||||
latency = mapped_column(db.Float, nullable=True)
|
||||
created_by_role = mapped_column(db.String, nullable=False)
|
||||
tool_process_data = mapped_column(sa.Text, nullable=True)
|
||||
message = mapped_column(sa.Text, nullable=True)
|
||||
message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
message_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
message_files = mapped_column(sa.Text, nullable=True)
|
||||
answer = db.Column(sa.Text, nullable=True)
|
||||
answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
total_price = mapped_column(sa.Numeric, nullable=True)
|
||||
currency = mapped_column(String, nullable=True)
|
||||
latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def files(self) -> list:
|
||||
|
|
@ -1771,80 +1771,80 @@ class MessageAgentThought(Base):
|
|||
class DatasetRetrieverResource(Base):
|
||||
__tablename__ = "dataset_retriever_resources"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
db.Index("dataset_retriever_resource_message_id_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
position = mapped_column(db.Integer, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_name = mapped_column(db.Text, nullable=False)
|
||||
dataset_name = mapped_column(sa.Text, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=True)
|
||||
document_name = mapped_column(db.Text, nullable=False)
|
||||
data_source_type = mapped_column(db.Text, nullable=True)
|
||||
document_name = mapped_column(sa.Text, nullable=False)
|
||||
data_source_type = mapped_column(sa.Text, nullable=True)
|
||||
segment_id = mapped_column(StringUUID, nullable=True)
|
||||
score = mapped_column(db.Float, nullable=True)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
hit_count = mapped_column(db.Integer, nullable=True)
|
||||
word_count = mapped_column(db.Integer, nullable=True)
|
||||
segment_position = mapped_column(db.Integer, nullable=True)
|
||||
index_node_hash = mapped_column(db.Text, nullable=True)
|
||||
retriever_from = mapped_column(db.Text, nullable=False)
|
||||
score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
index_node_hash = mapped_column(sa.Text, nullable=True)
|
||||
retriever_from = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
db.Index("tag_type_idx", "type"),
|
||||
db.Index("tag_name_idx", "name"),
|
||||
sa.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
sa.Index("tag_type_idx", "type"),
|
||||
sa.Index("tag_name_idx", "name"),
|
||||
)
|
||||
|
||||
TAG_TYPE_LIST = ["knowledge", "app"]
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(db.String(16), nullable=False)
|
||||
name = mapped_column(db.String(255), nullable=False)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TagBinding(Base):
|
||||
__tablename__ = "tag_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||
db.Index("tag_bind_target_id_idx", "target_id"),
|
||||
db.Index("tag_bind_tag_id_idx", "tag_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||
sa.Index("tag_bind_target_id_idx", "target_id"),
|
||||
sa.Index("tag_bind_tag_id_idx", "tag_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
tag_id = mapped_column(StringUUID, nullable=True)
|
||||
target_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TraceAppConfig(Base):
|
||||
__tablename__ = "trace_app_config"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||
db.Index("trace_app_config_app_id_idx", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||
sa.Index("trace_app_config_app_id_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tracing_provider = mapped_column(db.String(255), nullable=True)
|
||||
tracing_config = mapped_column(db.JSON, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
tracing_provider = mapped_column(String(255), nullable=True)
|
||||
tracing_config = mapped_column(sa.JSON, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
|
||||
@property
|
||||
def tracing_config_dict(self):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from datetime import datetime
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, text
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
|
|
@ -47,31 +47,31 @@ class Provider(Base):
|
|||
|
||||
__tablename__ = "providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_pkey"),
|
||||
db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
db.UniqueConstraint(
|
||||
sa.PrimaryKeyConstraint("id", name="provider_pkey"),
|
||||
sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
db.String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
db.String(40), nullable=True, server_default=text("''::character varying")
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
)
|
||||
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
|
||||
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
|
||||
quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
|
||||
quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -104,80 +104,80 @@ class ProviderModel(Base):
|
|||
|
||||
__tablename__ = "provider_models"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_model_pkey"),
|
||||
db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
db.UniqueConstraint(
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_pkey"),
|
||||
sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantDefaultModel(Base):
|
||||
__tablename__ = "tenant_default_models"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(Base):
|
||||
__tablename__ = "tenant_preferred_model_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderOrder(Base):
|
||||
__tablename__ = "provider_orders"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
|
||||
payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
|
||||
transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
|
||||
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[Optional[str]] = mapped_column(db.String(40))
|
||||
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
|
||||
payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
|
||||
payment_id: Mapped[Optional[str]] = mapped_column(String(191))
|
||||
transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
|
||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[Optional[str]] = mapped_column(String(40))
|
||||
total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(
|
||||
db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
|
||||
String(40), nullable=False, server_default=text("'wait_pay'::character varying")
|
||||
)
|
||||
paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelSetting(Base):
|
||||
|
|
@ -187,19 +187,19 @@ class ProviderModelSetting(Base):
|
|||
|
||||
__tablename__ = "provider_model_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
|
||||
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
|
||||
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
|
|
@ -209,17 +209,17 @@ class LoadBalancingModelConfig(Base):
|
|||
|
||||
__tablename__ = "load_balancing_model_configs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
|
||||
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
|
||||
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -1,49 +1,51 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(Base):
|
||||
__tablename__ = "data_source_oauth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
db.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
access_token = mapped_column(db.String(255), nullable=False)
|
||||
provider = mapped_column(db.String(255), nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info = mapped_column(JSONB, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(Base):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
|
||||
db.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
category = mapped_column(db.String(255), nullable=False)
|
||||
provider = mapped_column(db.String(255), nullable=False)
|
||||
credentials = mapped_column(db.Text, nullable=True) # JSON
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credentials = mapped_column(sa.Text, nullable=True) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from celery import states # type: ignore
|
||||
from sqlalchemy import DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -15,23 +17,23 @@ class CeleryTask(Base):
|
|||
|
||||
__tablename__ = "celery_taskmeta"
|
||||
|
||||
id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
task_id = mapped_column(db.String(155), unique=True)
|
||||
status = mapped_column(db.String(50), default=states.PENDING)
|
||||
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
task_id = mapped_column(String(155), unique=True)
|
||||
status = mapped_column(String(50), default=states.PENDING)
|
||||
result = mapped_column(db.PickleType, nullable=True)
|
||||
date_done = mapped_column(
|
||||
db.DateTime,
|
||||
DateTime,
|
||||
default=lambda: naive_utc_now(),
|
||||
onupdate=lambda: naive_utc_now(),
|
||||
nullable=True,
|
||||
)
|
||||
traceback = mapped_column(db.Text, nullable=True)
|
||||
name = mapped_column(db.String(155), nullable=True)
|
||||
args = mapped_column(db.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(db.LargeBinary, nullable=True)
|
||||
worker = mapped_column(db.String(155), nullable=True)
|
||||
retries = mapped_column(db.Integer, nullable=True)
|
||||
queue = mapped_column(db.String(155), nullable=True)
|
||||
traceback = mapped_column(sa.Text, nullable=True)
|
||||
name = mapped_column(String(155), nullable=True)
|
||||
args = mapped_column(sa.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(sa.LargeBinary, nullable=True)
|
||||
worker = mapped_column(String(155), nullable=True)
|
||||
retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
queue = mapped_column(String(155), nullable=True)
|
||||
|
||||
|
||||
class CeleryTaskSet(Base):
|
||||
|
|
@ -40,8 +42,8 @@ class CeleryTaskSet(Base):
|
|||
__tablename__ = "celery_tasksetmeta"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
)
|
||||
taskset_id = mapped_column(db.String(155), unique=True)
|
||||
taskset_id = mapped_column(String(155), unique=True)
|
||||
result = mapped_column(db.PickleType, nullable=True)
|
||||
date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
|
||||
date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
|||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, func
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
|
|
@ -25,33 +25,33 @@ from .types import StringUUID
|
|||
class ToolOAuthSystemClient(Base):
|
||||
__tablename__ = "tool_oauth_system_clients"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
plugin_id = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
|
||||
# tenant level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthTenantClient(Base):
|
||||
__tablename__ = "tool_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict:
|
||||
|
|
@ -65,35 +65,35 @@ class BuiltinToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_builtin_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||
)
|
||||
|
||||
# id of the tool provider
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(
|
||||
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
|
||||
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
|
||||
)
|
||||
# id of the tenant
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
# who created this tool provider
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# name of the tool provider
|
||||
provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
# credential type, e.g., "api-key", "oauth2"
|
||||
credential_type: Mapped[str] = mapped_column(
|
||||
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
|
||||
)
|
||||
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
|
|
@ -107,35 +107,35 @@ class ApiToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_api_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
|
||||
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
|
||||
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the api provider
|
||||
name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
|
||||
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
|
||||
# icon
|
||||
icon = mapped_column(db.String(255), nullable=False)
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema = mapped_column(db.Text, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
schema = mapped_column(sa.Text, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# who created this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description = mapped_column(db.Text, nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
# json format tools
|
||||
tools_str = mapped_column(db.Text, nullable=False)
|
||||
tools_str = mapped_column(sa.Text, nullable=False)
|
||||
# json format credentials
|
||||
credentials_str = mapped_column(db.Text, nullable=False)
|
||||
credentials_str = mapped_column(sa.Text, nullable=False)
|
||||
# privacy policy
|
||||
privacy_policy = mapped_column(db.String(255), nullable=True)
|
||||
privacy_policy = mapped_column(String(255), nullable=True)
|
||||
# custom_disclaimer
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def schema_type(self) -> ApiProviderSchemaType:
|
||||
|
|
@ -167,17 +167,17 @@ class ToolLabelBinding(Base):
|
|||
|
||||
__tablename__ = "tool_label_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
|
||||
db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
|
||||
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False)
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# label name
|
||||
label_name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
|
||||
class WorkflowToolProvider(Base):
|
||||
|
|
@ -187,38 +187,38 @@ class WorkflowToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_workflow_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
|
||||
db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
|
||||
sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the workflow provider
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# label of the workflow provider
|
||||
label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
|
||||
label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
|
||||
# icon
|
||||
icon: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# app id of the workflow provider
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# version of the workflow provider
|
||||
version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="")
|
||||
version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="")
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# parameter configuration
|
||||
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
|
||||
# privacy policy
|
||||
privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="")
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -245,38 +245,38 @@ class MCPToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_mcp_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
|
||||
db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
|
||||
db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
|
||||
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the mcp provider
|
||||
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# server identifier of the mcp provider
|
||||
server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
|
||||
server_identifier: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# encrypted url of the mcp provider
|
||||
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
server_url: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# hash of server_url for uniqueness check
|
||||
server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
|
||||
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# icon of the mcp provider
|
||||
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# encrypted credentials
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
# authed
|
||||
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
|
||||
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
# tools
|
||||
tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
|
||||
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
|
|
@ -347,35 +347,35 @@ class ToolModelInvoke(Base):
|
|||
"""
|
||||
|
||||
__tablename__ = "tool_model_invokes"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# who invoke this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
# provider
|
||||
provider = mapped_column(db.String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type = mapped_column(db.String(40), nullable=False)
|
||||
tool_type = mapped_column(String(40), nullable=False)
|
||||
# tool name
|
||||
tool_name = mapped_column(db.String(128), nullable=False)
|
||||
tool_name = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
model_parameters = mapped_column(db.Text, nullable=False)
|
||||
model_parameters = mapped_column(sa.Text, nullable=False)
|
||||
# prompt messages
|
||||
prompt_messages = mapped_column(db.Text, nullable=False)
|
||||
prompt_messages = mapped_column(sa.Text, nullable=False)
|
||||
# invoke response
|
||||
model_response = mapped_column(db.Text, nullable=False)
|
||||
model_response = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = mapped_column(db.Numeric(10, 7))
|
||||
currency = mapped_column(db.String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
@deprecated
|
||||
|
|
@ -386,13 +386,13 @@ class ToolConversationVariables(Base):
|
|||
|
||||
__tablename__ = "tool_conversation_variables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
|
||||
# add index for user_id and conversation_id
|
||||
db.Index("user_id_idx", "user_id"),
|
||||
db.Index("conversation_id_idx", "conversation_id"),
|
||||
sa.Index("user_id_idx", "user_id"),
|
||||
sa.Index("conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# conversation user id
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
|
|
@ -400,10 +400,10 @@ class ToolConversationVariables(Base):
|
|||
# conversation id
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
# variables pool
|
||||
variables_str = mapped_column(db.Text, nullable=False)
|
||||
variables_str = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def variables(self) -> Any:
|
||||
|
|
@ -417,11 +417,11 @@ class ToolFile(Base):
|
|||
|
||||
__tablename__ = "tool_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
db.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# conversation user id
|
||||
user_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# tenant id
|
||||
|
|
@ -429,11 +429,11 @@ class ToolFile(Base):
|
|||
# conversation id
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
# file key
|
||||
file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
file_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# mime type
|
||||
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
mimetype: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original url
|
||||
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
|
||||
original_url: Mapped[str] = mapped_column(String(2048), nullable=True)
|
||||
# name
|
||||
name: Mapped[str] = mapped_column(default="")
|
||||
# size
|
||||
|
|
@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base):
|
|||
|
||||
__tablename__ = "tool_published_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
|
||||
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
|
||||
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# id of the app
|
||||
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who published this tool
|
||||
description = mapped_column(db.Text, nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = mapped_column(db.Text, nullable=False)
|
||||
llm_description = mapped_column(sa.Text, nullable=False)
|
||||
# query description, query will be seem as a parameter of the tool,
|
||||
# to describe this parameter to llm, we need this field
|
||||
query_description = mapped_column(db.Text, nullable=False)
|
||||
query_description = mapped_column(sa.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = mapped_column(db.String(40), nullable=False)
|
||||
query_name = mapped_column(String(40), nullable=False)
|
||||
# name of the tool provider
|
||||
tool_name = mapped_column(db.String(40), nullable=False)
|
||||
tool_name = mapped_column(String(40), nullable=False)
|
||||
# author
|
||||
author = mapped_column(db.String(40), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
author = mapped_column(String(40), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> I18nObject:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from sqlalchemy import func
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
|
|
@ -11,18 +14,18 @@ from .types import StringUUID
|
|||
class SavedMessage(Base):
|
||||
__tablename__ = "saved_messages"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
|
|
@ -32,15 +35,15 @@ class SavedMessage(Base):
|
|||
class PinnedConversation(Base):
|
||||
__tablename__ = "pinned_conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_by_role = mapped_column(
|
||||
db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ from enum import Enum, StrEnum
|
|||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import orm
|
||||
from sqlalchemy import DateTime, orm
|
||||
|
||||
from core.file.constants import maybe_file_object
|
||||
from core.file.models import File
|
||||
|
|
@ -24,8 +25,7 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
|||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func
|
||||
from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
|
|
@ -118,33 +118,33 @@ class Workflow(Base):
|
|||
|
||||
__tablename__ = "workflows"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_pkey"),
|
||||
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_pkey"),
|
||||
sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
version: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
marked_name: Mapped[str] = mapped_column(default="", server_default="")
|
||||
marked_comment: Mapped[str] = mapped_column(default="", server_default="")
|
||||
graph: Mapped[str] = mapped_column(sa.Text)
|
||||
_features: Mapped[str] = mapped_column("features", sa.TEXT)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime,
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", db.Text, nullable=False, server_default="{}"
|
||||
"environment_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_conversation_variables: Mapped[str] = mapped_column(
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
"conversation_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_rag_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
|
||||
|
|
@ -521,31 +521,31 @@ class WorkflowRun(Base):
|
|||
|
||||
__tablename__ = "workflow_runs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
|
||||
db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
|
||||
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID)
|
||||
type: Mapped[str] = mapped_column(db.String(255))
|
||||
triggered_from: Mapped[str] = mapped_column(db.String(255))
|
||||
version: Mapped[str] = mapped_column(db.String(255))
|
||||
graph: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||
type: Mapped[str] = mapped_column(String(255))
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
version: Mapped[str] = mapped_column(String(255))
|
||||
graph: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
|
||||
error: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
|
||||
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
|
@ -735,29 +735,29 @@ class WorkflowNodeExecutionModel(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID)
|
||||
triggered_from: Mapped[str] = mapped_column(db.String(255))
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
index: Mapped[int] = mapped_column(db.Integer)
|
||||
predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
|
||||
node_id: Mapped[str] = mapped_column(db.String(255))
|
||||
node_type: Mapped[str] = mapped_column(db.String(255))
|
||||
title: Mapped[str] = mapped_column(db.String(255))
|
||||
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
process_data: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
outputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
status: Mapped[str] = mapped_column(db.String(255))
|
||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
|
||||
execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255))
|
||||
index: Mapped[int] = mapped_column(sa.Integer)
|
||||
predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
node_id: Mapped[str] = mapped_column(String(255))
|
||||
node_type: Mapped[str] = mapped_column(String(255))
|
||||
title: Mapped[str] = mapped_column(String(255))
|
||||
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
process_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
status: Mapped[str] = mapped_column(String(255))
|
||||
error: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
|
||||
execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(String(255))
|
||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
|
@ -865,19 +865,19 @@ class WorkflowAppLog(Base):
|
|||
|
||||
__tablename__ = "workflow_app_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
created_from: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def workflow_run(self):
|
||||
|
|
@ -902,12 +902,12 @@ class ConversationVariable(Base):
|
|||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
data: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
data: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
||||
def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None:
|
||||
|
|
@ -964,17 +964,17 @@ class WorkflowDraftVariable(Base):
|
|||
__allow_unmapped__ = True
|
||||
|
||||
# id is the unique identifier of a draft variable.
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime,
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime,
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=_naive_utc_datetime,
|
||||
server_default=func.current_timestamp(),
|
||||
|
|
@ -989,7 +989,7 @@ class WorkflowDraftVariable(Base):
|
|||
#
|
||||
# If it's not edited after creation, its value is `None`.
|
||||
last_edited_at: Mapped[datetime | None] = mapped_column(
|
||||
db.DateTime,
|
||||
DateTime,
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ dev = [
|
|||
"pytest-cov~=4.1.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"testcontainers~=4.10.0",
|
||||
"types-aiofiles~=24.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from libs.email_i18n import EmailType, get_email_i18n_service
|
|||
|
||||
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
|
||||
celery_redis = Redis(
|
||||
host=redis_config["hostname"],
|
||||
port=redis_config["port"],
|
||||
password=redis_config["password"],
|
||||
db=int(redis_config["virtual_host"]) if redis_config["virtual_host"] else 1,
|
||||
host=redis_config.get("hostname") or "localhost",
|
||||
port=redis_config.get("port") or 6379,
|
||||
password=redis_config.get("password") or None,
|
||||
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import yaml # type: ignore
|
|||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from packaging import version
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -269,7 +270,7 @@ class AppDslService:
|
|||
check_dependencies_pending_data = None
|
||||
if dependencies:
|
||||
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
|
||||
elif imported_version <= "0.1.5":
|
||||
elif parse_version(imported_version) <= parse_version("0.1.5"):
|
||||
if "workflow" in data:
|
||||
graph = data.get("workflow", {}).get("graph", {})
|
||||
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
|
|
@ -15,6 +16,7 @@ from libs.helper import RateLimiter
|
|||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -86,7 +88,8 @@ class AppGenerateService:
|
|||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
|
|
@ -101,7 +104,8 @@ class AppGenerateService:
|
|||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
|
|
@ -210,14 +214,27 @@ class AppGenerateService:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
|
||||
"""
|
||||
Get workflow
|
||||
:param app_model: app model
|
||||
:param invoke_from: invoke from
|
||||
:param workflow_id: optional workflow id to specify a specific version
|
||||
:return:
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# If workflow_id is specified, get the specific workflow version
|
||||
if workflow_id:
|
||||
try:
|
||||
workflow_uuid = uuid.UUID(workflow_id)
|
||||
except ValueError:
|
||||
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}")
|
||||
return workflow
|
||||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# fetch draft workflow by app_model
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
|
|
|||
|
|
@ -159,9 +159,9 @@ class BillingService:
|
|||
):
|
||||
limiter_key = f"{account_id}:{tenant_id}"
|
||||
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
|
||||
from controllers.console.error import CompilanceRateLimitError
|
||||
from controllers.console.error import ComplianceRateLimitError
|
||||
|
||||
raise CompilanceRateLimitError()
|
||||
raise ComplianceRateLimitError()
|
||||
|
||||
json = {
|
||||
"doc_name": doc_name,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from collections.abc import Callable, Sequence
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import asc, desc, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import ConversationVariable
|
||||
|
|
@ -15,6 +18,7 @@ from models.model import App, Conversation, EndUser, Message
|
|||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
ConversationVariableNotExistsError,
|
||||
ConversationVariableTypeMismatchError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
|
@ -220,3 +224,82 @@ class ConversationService:
|
|||
]
|
||||
|
||||
return InfiniteScrollPagination(variables, limit, has_more)
|
||||
|
||||
@classmethod
|
||||
def update_conversation_variable(
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
variable_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
new_value: Any,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a conversation variable's value.
|
||||
|
||||
Args:
|
||||
app_model: The app model
|
||||
conversation_id: The conversation ID
|
||||
variable_id: The variable ID to update
|
||||
user: The user (Account or EndUser)
|
||||
new_value: The new value for the variable
|
||||
|
||||
Returns:
|
||||
Dictionary containing the updated variable information
|
||||
|
||||
Raises:
|
||||
ConversationNotExistsError: If the conversation doesn't exist
|
||||
ConversationVariableNotExistsError: If the variable doesn't exist
|
||||
ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type
|
||||
"""
|
||||
# Verify conversation exists and user has access
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
# Get the existing conversation variable
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.where(ConversationVariable.conversation_id == conversation.id)
|
||||
.where(ConversationVariable.id == variable_id)
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
existing_variable = session.scalar(stmt)
|
||||
if not existing_variable:
|
||||
raise ConversationVariableNotExistsError()
|
||||
|
||||
# Convert existing variable to Variable object
|
||||
current_variable = existing_variable.to_variable()
|
||||
|
||||
# Validate that the new value type matches the expected variable type
|
||||
expected_type = SegmentType(current_variable.value_type)
|
||||
if not expected_type.is_valid(new_value):
|
||||
inferred_type = SegmentType.infer_segment_type(new_value)
|
||||
raise ConversationVariableTypeMismatchError(
|
||||
f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, "
|
||||
f"but got {inferred_type.value if inferred_type else 'unknown'} type"
|
||||
)
|
||||
|
||||
# Create updated variable with new value only, preserving everything else
|
||||
updated_variable_dict = {
|
||||
"id": current_variable.id,
|
||||
"name": current_variable.name,
|
||||
"description": current_variable.description,
|
||||
"value_type": current_variable.value_type,
|
||||
"value": new_value,
|
||||
"selector": current_variable.selector,
|
||||
}
|
||||
|
||||
updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict)
|
||||
|
||||
# Use the conversation variable updater to persist the changes
|
||||
updater = conversation_variable_updater_factory()
|
||||
updater.update(conversation_id, updated_variable)
|
||||
updater.flush()
|
||||
|
||||
# Return the updated variable data
|
||||
return {
|
||||
"created_at": existing_variable.created_at,
|
||||
"updated_at": naive_utc_now(), # Update timestamp
|
||||
**updated_variable.model_dump(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ class DatasetService:
|
|||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
|
||||
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
|
||||
|
||||
@staticmethod
|
||||
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
|
||||
|
|
@ -426,7 +426,7 @@ class DatasetService:
|
|||
raise ValueError("External knowledge api id is required.")
|
||||
# Update metadata fields
|
||||
dataset.updated_by = user.id if user else None
|
||||
dataset.updated_at = datetime.datetime.utcnow()
|
||||
dataset.updated_at = naive_utc_now()
|
||||
db.session.add(dataset)
|
||||
|
||||
# Update external knowledge binding
|
||||
|
|
@ -2498,6 +2498,7 @@ class SegmentService:
|
|||
|
||||
db.session.add(segment_document)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
document.word_count += segment_document.word_count
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
|
@ -2582,6 +2583,7 @@ class SegmentService:
|
|||
else:
|
||||
keywords_list.append(None)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
document.word_count += increment_word_count
|
||||
db.session.add(document)
|
||||
try:
|
||||
|
|
@ -2643,6 +2645,7 @@ class SegmentService:
|
|||
db.session.commit()
|
||||
# update document word count
|
||||
if word_count_change != 0:
|
||||
assert document.word_count is not None
|
||||
document.word_count = max(0, document.word_count + word_count_change)
|
||||
db.session.add(document)
|
||||
# update segment index task
|
||||
|
|
@ -2718,6 +2721,7 @@ class SegmentService:
|
|||
word_count_change = segment.word_count - word_count_change
|
||||
# update document word count
|
||||
if word_count_change != 0:
|
||||
assert document.word_count is not None
|
||||
document.word_count = max(0, document.word_count + word_count_change)
|
||||
db.session.add(document)
|
||||
db.session.add(segment)
|
||||
|
|
@ -2781,6 +2785,7 @@ class SegmentService:
|
|||
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
|
||||
db.session.delete(segment)
|
||||
# update document word count
|
||||
assert document.word_count is not None
|
||||
document.word_count -= segment.word_count
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
|
@ -2825,7 +2830,7 @@ class SegmentService:
|
|||
)
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segmment_ids = []
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
|
|
@ -2835,10 +2840,10 @@ class SegmentService:
|
|||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
|
|
@ -2852,7 +2857,7 @@ class SegmentService:
|
|||
)
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segmment_ids = []
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
|
|
@ -2862,10 +2867,10 @@ class SegmentService:
|
|||
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
|
||||
|
|
@ -3123,7 +3128,7 @@ class SegmentService:
|
|||
# check segment
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
|
|||
|
|
@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception):
|
|||
|
||||
class IsDraftWorkflowError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowIdFormatError(Exception):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -15,3 +15,7 @@ class ConversationCompletedError(Exception):
|
|||
|
||||
class ConversationVariableNotExistsError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class ConversationVariableTypeMismatchError(BaseServiceError):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -79,7 +79,10 @@ class MetadataService:
|
|||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
value = doc_metadata.pop(old_name, None)
|
||||
doc_metadata[name] = value
|
||||
document.doc_metadata = doc_metadata
|
||||
|
|
@ -109,7 +112,10 @@ class MetadataService:
|
|||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(metadata.name, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
|
|
@ -137,7 +143,6 @@ class MetadataService:
|
|||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
if documents:
|
||||
|
|
@ -153,6 +158,7 @@ class MetadataService:
|
|||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Enable built-in field failed")
|
||||
|
|
@ -166,13 +172,15 @@ class MetadataService:
|
|||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(BuiltInField.document_name.value, None)
|
||||
doc_metadata.pop(BuiltInField.uploader.value, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date.value, None)
|
||||
|
|
@ -181,6 +189,7 @@ class MetadataService:
|
|||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Disable built-in field failed")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
|
||||
from models.engine import db
|
||||
|
|
@ -38,7 +39,7 @@ class PluginDataMigration:
|
|||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
|
|
@ -94,7 +95,7 @@ limit 1000"""
|
|||
:provider_name
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
conn.execute(sa.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
|
|
@ -148,7 +149,7 @@ limit 1000"""
|
|||
params = {"last_id": last_id or ""}
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql), params)
|
||||
rs = conn.execute(sa.text(sql), params)
|
||||
|
||||
current_iter_count = 0
|
||||
batch_updates = []
|
||||
|
|
@ -193,7 +194,7 @@ limit 1000"""
|
|||
SET {provider_column_name} = :updated_value
|
||||
WHERE id = :record_id
|
||||
"""
|
||||
conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
|
||||
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any, Optional
|
|||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -197,7 +198,7 @@ class PluginMigration:
|
|||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.execute(
|
||||
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
)
|
||||
result = []
|
||||
for row in rs:
|
||||
|
|
|
|||
|
|
@ -486,10 +486,10 @@ class BuiltinToolManageService:
|
|||
oauth_params = encrypter.decrypt(user_client.oauth_params)
|
||||
return oauth_params
|
||||
|
||||
# only verified provider can use custom oauth client
|
||||
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
|
||||
tenant_id, provider.plugin_unique_identifier
|
||||
)
|
||||
# only verified provider can use official oauth client
|
||||
is_verified = not isinstance(
|
||||
provider_controller, PluginToolProviderController
|
||||
) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
|
||||
if not is_verified:
|
||||
return oauth_params
|
||||
|
||||
|
|
|
|||
|
|
@ -422,7 +422,7 @@ class WorkflowDraftVariableService:
|
|||
description=conv_var.description,
|
||||
)
|
||||
draft_conv_vars.append(draft_var)
|
||||
_batch_upsert_draft_varaible(
|
||||
_batch_upsert_draft_variable(
|
||||
self._session,
|
||||
draft_conv_vars,
|
||||
policy=_UpsertPolicy.IGNORE,
|
||||
|
|
@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum):
|
|||
OVERWRITE = "overwrite"
|
||||
|
||||
|
||||
def _batch_upsert_draft_varaible(
|
||||
def _batch_upsert_draft_variable(
|
||||
session: Session,
|
||||
draft_vars: Sequence[WorkflowDraftVariable],
|
||||
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
|
||||
|
|
@ -721,7 +721,7 @@ class DraftVariableSaver:
|
|||
draft_vars = self._build_variables_from_start_mapping(outputs)
|
||||
else:
|
||||
draft_vars = self._build_variables_from_mapping(outputs)
|
||||
_batch_upsert_draft_varaible(self._session, draft_vars)
|
||||
_batch_upsert_draft_variable(self._session, draft_vars)
|
||||
|
||||
@staticmethod
|
||||
def _should_variable_be_editable(node_id: str, name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -129,7 +129,10 @@ class WorkflowService:
|
|||
if not workflow:
|
||||
return None
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
|
||||
raise IsDraftWorkflowError(
|
||||
f"Cannot use draft workflow version. Workflow ID: {workflow_id}. "
|
||||
f"Please use a published workflow version or leave workflow_id empty."
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
|
|
@ -442,9 +445,9 @@ class WorkflowService:
|
|||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
Run free workflow node
|
||||
"""
|
||||
# run draft workflow node
|
||||
# run free workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
node_execution = self._handle_node_run_result(
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ def batch_create_segment_to_index_task(
|
|||
db.session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
# update document word count
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
db.session.add(dataset_document)
|
||||
# add index to db
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
from collections.abc import Callable
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task # type: ignore
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
|
@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
|||
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query_sql), params)
|
||||
rs = conn.execute(sa.text(query_sql), params)
|
||||
if rs.rowcount == 0:
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
import uuid
|
||||
|
||||
import tablestore
|
||||
from _pytest.python_api import approx
|
||||
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
|
||||
TableStoreConfig,
|
||||
|
|
@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import (
|
|||
|
||||
|
||||
class TableStoreVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
def __init__(self, normalize_full_text_score: bool = False):
|
||||
super().__init__()
|
||||
self.vector = TableStoreVector(
|
||||
collection_name=self.collection_name,
|
||||
|
|
@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest):
|
|||
instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
|
||||
access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
|
||||
access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
|
||||
normalize_full_text_bm25_score=normalize_full_text_score,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest):
|
|||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||
assert not hasattr(docs[0], "score")
|
||||
if self.vector._config.normalize_full_text_bm25_score:
|
||||
assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
|
||||
else:
|
||||
assert docs[0].metadata.get("score") is None
|
||||
|
||||
# return none if normalize_full_text_score=true and score_threshold > 0
|
||||
docs = self.vector.search_by_full_text(
|
||||
get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
|
||||
)
|
||||
if self.vector._config.normalize_full_text_bm25_score:
|
||||
assert len(docs) == 0
|
||||
else:
|
||||
assert len(docs) == 1
|
||||
assert docs[0].metadata["doc_id"] == self.example_doc_id
|
||||
assert docs[0].metadata.get("score") is None
|
||||
|
||||
docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
|
||||
assert len(docs) == 0
|
||||
|
|
@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest):
|
|||
|
||||
def test_tablestore_vector(setup_mock_redis):
|
||||
TableStoreVectorTest().run_all_tests()
|
||||
TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
|
||||
TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
TestContainers-based integration test configuration for Dify API.
|
||||
|
||||
This module provides containerized test infrastructure using TestContainers library
|
||||
to spin up real database and service instances for integration testing. This approach
|
||||
ensures tests run against actual service implementations rather than mocks, providing
|
||||
more reliable and realistic test scenarios.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
from testcontainers.core.container import DockerContainer
|
||||
from testcontainers.core.waiting_utils import wait_for_logs
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from app_factory import create_app
|
||||
from models import db
|
||||
|
||||
# Configure logging for test containers
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyTestContainers:
|
||||
"""
|
||||
Manages all test containers required for Dify integration tests.
|
||||
|
||||
This class provides a centralized way to manage multiple containers
|
||||
needed for comprehensive integration testing, including databases,
|
||||
caches, and search engines.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize container management with default configurations."""
|
||||
self.postgres: Optional[PostgresContainer] = None
|
||||
self.redis: Optional[RedisContainer] = None
|
||||
self.dify_sandbox: Optional[DockerContainer] = None
|
||||
self._containers_started = False
|
||||
logger.info("DifyTestContainers initialized - ready to manage test containers")
|
||||
|
||||
def start_containers_with_env(self) -> None:
|
||||
"""
|
||||
Start all required containers for integration testing.
|
||||
|
||||
This method initializes and starts PostgreSQL, Redis
|
||||
containers with appropriate configurations for Dify testing. Containers
|
||||
are started in dependency order to ensure proper initialization.
|
||||
"""
|
||||
if self._containers_started:
|
||||
logger.info("Containers already started - skipping container startup")
|
||||
return
|
||||
|
||||
logger.info("Starting test containers for Dify integration tests...")
|
||||
|
||||
# Start PostgreSQL container for main application database
|
||||
# PostgreSQL is used for storing user data, workflows, and application state
|
||||
logger.info("Initializing PostgreSQL container...")
|
||||
self.postgres = PostgresContainer(
|
||||
image="postgres:16-alpine",
|
||||
)
|
||||
self.postgres.start()
|
||||
db_host = self.postgres.get_container_host_ip()
|
||||
db_port = self.postgres.get_exposed_port(5432)
|
||||
os.environ["DB_HOST"] = db_host
|
||||
os.environ["DB_PORT"] = str(db_port)
|
||||
os.environ["DB_USERNAME"] = self.postgres.username
|
||||
os.environ["DB_PASSWORD"] = self.postgres.password
|
||||
os.environ["DB_DATABASE"] = self.postgres.dbname
|
||||
logger.info(
|
||||
"PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s",
|
||||
db_host,
|
||||
db_port,
|
||||
self.postgres.username,
|
||||
self.postgres.dbname,
|
||||
)
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
logger.info("Waiting for PostgreSQL to be ready to accept connections...")
|
||||
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
|
||||
logger.info("PostgreSQL container is ready and accepting connections")
|
||||
|
||||
# Install uuid-ossp extension for UUID generation
|
||||
logger.info("Installing uuid-ossp extension...")
|
||||
try:
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=self.postgres.username,
|
||||
password=self.postgres.password,
|
||||
database=self.postgres.dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
|
||||
cursor.close()
|
||||
conn.close()
|
||||
logger.info("uuid-ossp extension installed successfully")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to install uuid-ossp extension: %s", e)
|
||||
|
||||
# Set up storage environment variables
|
||||
os.environ["STORAGE_TYPE"] = "opendal"
|
||||
os.environ["OPENDAL_SCHEME"] = "fs"
|
||||
os.environ["OPENDAL_FS_ROOT"] = "storage"
|
||||
|
||||
# Start Redis container for caching and session management
|
||||
# Redis is used for storing session data, cache entries, and temporary data
|
||||
logger.info("Initializing Redis container...")
|
||||
self.redis = RedisContainer(image="redis:latest", port=6379)
|
||||
self.redis.start()
|
||||
redis_host = self.redis.get_container_host_ip()
|
||||
redis_port = self.redis.get_exposed_port(6379)
|
||||
os.environ["REDIS_HOST"] = redis_host
|
||||
os.environ["REDIS_PORT"] = str(redis_port)
|
||||
logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port)
|
||||
|
||||
# Wait for Redis to be ready
|
||||
logger.info("Waiting for Redis to be ready to accept connections...")
|
||||
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
|
||||
logger.info("Redis container is ready and accepting connections")
|
||||
|
||||
# Start Dify Sandbox container for code execution environment
|
||||
# Dify Sandbox provides a secure environment for executing user code
|
||||
logger.info("Initializing Dify Sandbox container...")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
|
||||
self.dify_sandbox.with_exposed_ports(8194)
|
||||
self.dify_sandbox.env = {
|
||||
"API_KEY": "test_api_key",
|
||||
}
|
||||
self.dify_sandbox.start()
|
||||
sandbox_host = self.dify_sandbox.get_container_host_ip()
|
||||
sandbox_port = self.dify_sandbox.get_exposed_port(8194)
|
||||
os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}"
|
||||
os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key"
|
||||
logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port)
|
||||
|
||||
# Wait for Dify Sandbox to be ready
|
||||
logger.info("Waiting for Dify Sandbox to be ready to accept connections...")
|
||||
wait_for_logs(self.dify_sandbox, "config init success", timeout=60)
|
||||
logger.info("Dify Sandbox container is ready and accepting connections")
|
||||
|
||||
self._containers_started = True
|
||||
logger.info("All test containers started successfully")
|
||||
|
||||
def stop_containers(self) -> None:
|
||||
"""
|
||||
Stop and clean up all test containers.
|
||||
|
||||
This method ensures proper cleanup of all containers to prevent
|
||||
resource leaks and conflicts between test runs.
|
||||
"""
|
||||
if not self._containers_started:
|
||||
logger.info("No containers to stop - containers were not started")
|
||||
return
|
||||
|
||||
logger.info("Stopping and cleaning up test containers...")
|
||||
containers = [self.redis, self.postgres, self.dify_sandbox]
|
||||
for container in containers:
|
||||
if container:
|
||||
try:
|
||||
container_name = container.image
|
||||
logger.info("Stopping container: %s", container_name)
|
||||
container.stop()
|
||||
logger.info("Successfully stopped container: %s", container_name)
|
||||
except Exception as e:
|
||||
# Log error but don't fail the test cleanup
|
||||
logger.warning("Failed to stop container %s: %s", container, e)
|
||||
|
||||
self._containers_started = False
|
||||
logger.info("All test containers stopped and cleaned up successfully")
|
||||
|
||||
|
||||
# Global container manager instance
|
||||
_container_manager = DifyTestContainers()
|
||||
|
||||
|
||||
def _create_app_with_containers() -> Flask:
|
||||
"""
|
||||
Create Flask application configured to use test containers.
|
||||
|
||||
This function creates a Flask application instance that is configured
|
||||
to connect to the test containers instead of the default development
|
||||
or production databases.
|
||||
|
||||
Returns:
|
||||
Flask: Configured Flask application for containerized testing
|
||||
"""
|
||||
logger.info("Creating Flask application with test container configuration...")
|
||||
|
||||
# Re-create the config after environment variables have been set
|
||||
from configs import dify_config
|
||||
|
||||
# Force re-creation of config with new environment variables
|
||||
dify_config.__dict__.clear()
|
||||
dify_config.__init__()
|
||||
|
||||
# Create and configure the Flask application
|
||||
logger.info("Initializing Flask application...")
|
||||
app = create_app()
|
||||
logger.info("Flask application created successfully")
|
||||
|
||||
# Initialize database schema
|
||||
logger.info("Creating database schema...")
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
logger.info("Database schema created successfully")
|
||||
|
||||
logger.info("Flask application configured and ready for testing")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]:
|
||||
"""
|
||||
Session-scoped fixture to manage test containers.
|
||||
|
||||
This fixture ensures containers are started once per test session
|
||||
and properly cleaned up when all tests are complete. This approach
|
||||
improves test performance by reusing containers across multiple tests.
|
||||
|
||||
Yields:
|
||||
DifyTestContainers: Container manager instance
|
||||
"""
|
||||
logger.info("=== Starting test session container management ===")
|
||||
_container_manager.start_containers_with_env()
|
||||
logger.info("Test containers ready for session")
|
||||
yield _container_manager
|
||||
logger.info("=== Cleaning up test session containers ===")
|
||||
_container_manager.stop_containers()
|
||||
logger.info("Test session container cleanup completed")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def flask_app_with_containers(set_up_containers_and_env) -> Flask:
|
||||
"""
|
||||
Session-scoped Flask application fixture using test containers.
|
||||
|
||||
This fixture provides a Flask application instance that is configured
|
||||
to use the test containers for all database and service connections.
|
||||
|
||||
Args:
|
||||
containers: Container manager fixture
|
||||
|
||||
Returns:
|
||||
Flask: Configured Flask application
|
||||
"""
|
||||
logger.info("=== Creating session-scoped Flask application ===")
|
||||
app = _create_app_with_containers()
|
||||
logger.info("Session-scoped Flask application created successfully")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
|
||||
"""
|
||||
Request context fixture for containerized Flask application.
|
||||
|
||||
This fixture provides a Flask request context for tests that need
|
||||
to interact with the Flask application within a request scope.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
None: Request context is active during yield
|
||||
"""
|
||||
logger.debug("Creating Flask request context...")
|
||||
with flask_app_with_containers.test_request_context():
|
||||
logger.debug("Flask request context active")
|
||||
yield
|
||||
logger.debug("Flask request context closed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
|
||||
"""
|
||||
Test client fixture for containerized Flask application.
|
||||
|
||||
This fixture provides a Flask test client that can be used to make
|
||||
HTTP requests to the containerized application for integration testing.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
FlaskClient: Test client instance
|
||||
"""
|
||||
logger.debug("Creating Flask test client...")
|
||||
with flask_app_with_containers.test_client() as client:
|
||||
logger.debug("Flask test client ready")
|
||||
yield client
|
||||
logger.debug("Flask test client closed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Database session fixture for containerized testing.
|
||||
|
||||
This fixture provides a SQLAlchemy database session that is connected
|
||||
to the test PostgreSQL container, allowing tests to interact with
|
||||
the database directly.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
Session: Database session instance
|
||||
"""
|
||||
logger.debug("Creating database session...")
|
||||
with flask_app_with_containers.app_context():
|
||||
session = db.session()
|
||||
logger.debug("Database session created and ready")
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
class TestStorageKeyLoader(unittest.TestCase):
|
||||
"""
|
||||
Integration tests for StorageKeyLoader class.
|
||||
|
||||
Tests the batched loading of storage keys from the database for files
|
||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data before each test method."""
|
||||
self.session = db.session()
|
||||
self.tenant_id = str(uuid4())
|
||||
self.user_id = str(uuid4())
|
||||
self.conversation_id = str(uuid4())
|
||||
|
||||
# Create test data that will be cleaned up after each test
|
||||
self.test_upload_files = []
|
||||
self.test_tool_files = []
|
||||
|
||||
# Create StorageKeyLoader instance
|
||||
self.loader = StorageKeyLoader(self.session, self.tenant_id)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test data after each test method."""
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if storage_key is None:
|
||||
storage_key = f"test_storage_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
self.session.add(upload_file)
|
||||
self.session.flush()
|
||||
self.test_upload_files.append(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if file_key is None:
|
||||
file_key = f"test_file_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
tool_file = ToolFile()
|
||||
tool_file.id = file_id
|
||||
tool_file.user_id = self.user_id
|
||||
tool_file.tenant_id = tenant_id
|
||||
tool_file.conversation_id = self.conversation_id
|
||||
tool_file.file_key = file_key
|
||||
tool_file.mimetype = "text/plain"
|
||||
tool_file.original_url = "http://example.com/file.txt"
|
||||
tool_file.name = "test_tool_file.txt"
|
||||
tool_file.size = 2048
|
||||
|
||||
self.session.add(tool_file)
|
||||
self.session.flush()
|
||||
self.test_tool_files.append(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(
|
||||
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
|
||||
) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
||||
file_related_id = None
|
||||
remote_url = None
|
||||
|
||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
||||
file_related_id = related_id
|
||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
remote_url = "https://example.com/test_file.txt"
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
def test_load_storage_keys_local_file(self):
|
||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_remote_url(self):
|
||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_tool_file(self):
|
||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||
# Create test data
|
||||
tool_file = self._create_tool_file()
|
||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_mixed_methods(self):
|
||||
"""Test batch loading with mixed transfer methods."""
|
||||
# Create test data for different transfer methods
|
||||
upload_file1 = self._create_upload_file()
|
||||
upload_file2 = self._create_upload_file()
|
||||
tool_file = self._create_tool_file()
|
||||
|
||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
files = [file1, file2, file3]
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
assert file1._storage_key == upload_file1.key
|
||||
assert file2._storage_key == upload_file2.key
|
||||
assert file3._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_empty_list(self):
|
||||
"""Test with empty file list."""
|
||||
# Should not raise any exceptions
|
||||
self.loader.load_storage_keys([])
|
||||
|
||||
def test_load_storage_keys_tenant_mismatch(self):
|
||||
"""Test tenant_id validation."""
|
||||
# Create file with different tenant_id
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(
|
||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
||||
)
|
||||
|
||||
# Should raise ValueError for tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_missing_file_id(self):
|
||||
"""Test with None file.related_id."""
|
||||
# Create a file with valid parameters first, then manually set related_id to None
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = None
|
||||
|
||||
# Should raise ValueError for None file related_id
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert str(context.value) == "file id should not be None."
|
||||
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
||||
"""Test with missing UploadFile database records."""
|
||||
# Create file with non-existent upload file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
||||
"""Test with missing ToolFile database records."""
|
||||
# Create file with non-existent tool file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_invalid_uuid(self):
|
||||
"""Test with invalid UUID format."""
|
||||
# Create a file with valid parameters first, then manually set invalid related_id
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = "invalid-uuid-format"
|
||||
|
||||
# Should raise ValueError for invalid UUID
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_batch_efficiency(self):
|
||||
"""Test batched operations use efficient queries."""
|
||||
# Create multiple files of different types
|
||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
||||
|
||||
files = []
|
||||
files.extend(
|
||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
||||
)
|
||||
files.extend(
|
||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
||||
)
|
||||
|
||||
# Mock the session to count queries
|
||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
||||
assert mock_scalars.call_count == 2
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
for i, file in enumerate(files[:3]):
|
||||
assert file._storage_key == upload_files[i].key
|
||||
for i, file in enumerate(files[3:]):
|
||||
assert file._storage_key == tool_files[i].file_key
|
||||
|
||||
def test_load_storage_keys_tenant_isolation(self):
|
||||
"""Test that tenant isolation works correctly."""
|
||||
# Create files for different tenants
|
||||
other_tenant_id = str(uuid4())
|
||||
|
||||
# Create upload file for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file_other.id = str(uuid4())
|
||||
self.session.add(upload_file_other)
|
||||
self.session.flush()
|
||||
|
||||
# Create file for other tenant but try to load with current tenant's loader
|
||||
file_other = self._create_file(
|
||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError due to tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
# Current tenant's file should still work
|
||||
self.loader.load_storage_keys([file_current])
|
||||
assert file_current._storage_key == upload_file_current.key
|
||||
|
||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
||||
# Create files for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create file for different tenant
|
||||
other_tenant_id = str(uuid4())
|
||||
file_other = self._create_file(
|
||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError on tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_current, file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_duplicate_file_ids(self):
|
||||
"""Test handling of duplicate file IDs in the batch."""
|
||||
# Create upload file
|
||||
upload_file = self._create_upload_file()
|
||||
|
||||
# Create two File objects with same related_id
|
||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should handle duplicates gracefully
|
||||
self.loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Both files should have the same storage key
|
||||
assert file1._storage_key == upload_file.key
|
||||
assert file2._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_session_isolation(self):
|
||||
"""Test that the loader uses the provided session correctly."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Create loader with different session (same underlying connection)
|
||||
|
||||
with Session(bind=db.engine) as other_session:
|
||||
other_loader = StorageKeyLoader(other_session, self.tenant_id)
|
||||
with pytest.raises(ValueError):
|
||||
other_loader.load_storage_keys([file])
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
|
||||
CODE_LANGUAGE = "unsupported_language"
|
||||
|
||||
|
||||
def test_unsupported_with_code_template():
|
||||
with pytest.raises(CodeExecutionError) as e:
|
||||
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
|
||||
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for JavaScript code executor functionality."""
|
||||
|
||||
def test_javascript_plain(self, flask_app_with_containers):
|
||||
"""Test basic JavaScript code execution with console.log output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = 'console.log("Hello World")'
|
||||
result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
|
||||
assert result_message == "Hello World\n"
|
||||
|
||||
def test_javascript_json(self, flask_app_with_containers):
|
||||
"""Test JavaScript code execution with JSON output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = dedent("""
|
||||
obj = {'Hello': 'World'}
|
||||
console.log(JSON.stringify(obj))
|
||||
""")
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
|
||||
assert result == '{"Hello":"World"}\n'
|
||||
|
||||
def test_javascript_with_code_template(self, flask_app_with_containers):
|
||||
"""Test JavaScript workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
JavascriptCodeProvider, _ = self.javascript_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JAVASCRIPT,
|
||||
code=JavascriptCodeProvider.get_default_code(),
|
||||
inputs={"arg1": "Hello", "arg2": "World"},
|
||||
)
|
||||
assert result == {"result": "HelloWorld"}
|
||||
|
||||
def test_javascript_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test JavaScript template transformer runner script generation"""
|
||||
_, NodeJsTemplateTransformer = self.javascript_imports
|
||||
|
||||
runner_script = NodeJsTemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
import base64
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for Jinja2 code executor functionality."""
|
||||
|
||||
def test_jinja2(self, flask_app_with_containers):
|
||||
"""Test basic Jinja2 template execution with variable substitution"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
template = "Hello {{template}}"
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
|
||||
)
|
||||
assert result == "<<RESULT>>Hello World<<RESULT>>\n"
|
||||
|
||||
def test_jinja2_with_code_template(self, flask_app_with_containers):
|
||||
"""Test Jinja2 workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"}
|
||||
)
|
||||
assert result == {"result": "Hello World"}
|
||||
|
||||
def test_jinja2_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test Jinja2 template transformer runner script generation"""
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestPython3CodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for Python3 code executor functionality."""
|
||||
|
||||
def test_python3_plain(self, flask_app_with_containers):
|
||||
"""Test basic Python3 code execution with print output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = 'print("Hello World")'
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
|
||||
assert result == "Hello World\n"
|
||||
|
||||
def test_python3_json(self, flask_app_with_containers):
|
||||
"""Test Python3 code execution with JSON output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = dedent("""
|
||||
import json
|
||||
print(json.dumps({'Hello': 'World'}))
|
||||
""")
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
|
||||
assert result == '{"Hello": "World"}\n'
|
||||
|
||||
def test_python3_with_code_template(self, flask_app_with_containers):
|
||||
"""Test Python3 workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
Python3CodeProvider, _ = self.python3_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code=Python3CodeProvider.get_default_code(),
|
||||
inputs={"arg1": "Hello", "arg2": "World"},
|
||||
)
|
||||
assert result == {"result": "HelloWorld"}
|
||||
|
||||
def test_python3_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test Python3 template transformer runner script generation"""
|
||||
_, Python3TemplateTransformer = self.python3_imports
|
||||
|
||||
runner_script = Python3TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Test utilities for code executor integration tests.
|
||||
|
||||
This module provides lazy import functions to avoid module loading issues
|
||||
that occur when modules are imported before the flask_app_with_containers fixture
|
||||
has set up the proper environment variables and configuration.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def force_reload_code_executor():
|
||||
"""
|
||||
Force reload the code_executor module to reinitialize code_execution_endpoint_url.
|
||||
|
||||
This function should be called after setting up environment variables
|
||||
to ensure the code_execution_endpoint_url is initialized with the correct value.
|
||||
"""
|
||||
try:
|
||||
import core.helper.code_executor.code_executor
|
||||
|
||||
importlib.reload(core.helper.code_executor.code_executor)
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the test
|
||||
print(f"Warning: Failed to reload code_executor module: {e}")
|
||||
|
||||
|
||||
def get_code_executor_imports():
|
||||
"""
|
||||
Lazy import function for core CodeExecutor classes.
|
||||
|
||||
Returns:
|
||||
tuple: (CodeExecutor, CodeLanguage) classes
|
||||
"""
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
|
||||
return CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
def get_javascript_imports():
|
||||
"""
|
||||
Lazy import function for JavaScript-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
|
||||
return JavascriptCodeProvider, NodeJsTemplateTransformer
|
||||
|
||||
|
||||
def get_python3_imports():
|
||||
"""
|
||||
Lazy import function for Python3-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (Python3CodeProvider, Python3TemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
|
||||
return Python3CodeProvider, Python3TemplateTransformer
|
||||
|
||||
|
||||
def get_jinja2_imports():
|
||||
"""
|
||||
Lazy import function for Jinja2-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (None, Jinja2TemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
|
||||
return None, Jinja2TemplateTransformer
|
||||
|
||||
|
||||
class CodeExecutorTestMixin:
|
||||
"""
|
||||
Mixin class providing lazy import methods for code executor tests.
|
||||
|
||||
This mixin helps avoid module loading issues by deferring imports
|
||||
until after the flask_app_with_containers fixture has set up the environment.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
"""
|
||||
Setup method called before each test method.
|
||||
Force reload the code_executor module to ensure fresh initialization.
|
||||
"""
|
||||
force_reload_code_executor()
|
||||
|
||||
@property
|
||||
def code_executor_imports(self):
|
||||
"""Property to get CodeExecutor and CodeLanguage classes."""
|
||||
return get_code_executor_imports()
|
||||
|
||||
@property
|
||||
def javascript_imports(self):
|
||||
"""Property to get JavaScript-specific classes."""
|
||||
return get_javascript_imports()
|
||||
|
||||
@property
|
||||
def python3_imports(self):
|
||||
"""Property to get Python3-specific classes."""
|
||||
return get_python3_imports()
|
||||
|
||||
@property
|
||||
def jinja2_imports(self):
|
||||
"""Property to get Jinja2-specific classes."""
|
||||
return get_jinja2_imports()
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
import io
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.console.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
|
||||
class TestFileUploadSecurity:
|
||||
"""Test file upload security logic without complex framework setup"""
|
||||
|
||||
# Test 1: Basic file validation
|
||||
def test_should_validate_file_presence(self):
|
||||
"""Test that missing file is detected"""
|
||||
from flask import Flask, request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
# Simulate the check in FileApi.post()
|
||||
if "file" not in request.files:
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
raise NoFileUploadedError()
|
||||
|
||||
def test_should_validate_multiple_files(self):
|
||||
"""Test that multiple files are rejected"""
|
||||
from flask import Flask, request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
file_data = {
|
||||
"file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"),
|
||||
"file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
|
||||
# Simulate the check in FileApi.post()
|
||||
if len(request.files) > 1:
|
||||
with pytest.raises(TooManyFilesError):
|
||||
raise TooManyFilesError()
|
||||
|
||||
def test_should_validate_empty_filename(self):
|
||||
"""Test that empty filename is rejected"""
|
||||
from flask import Flask, request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")}
|
||||
|
||||
with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"):
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
raise FilenameNotExistsError
|
||||
|
||||
# Test 2: Security - Filename sanitization
|
||||
def test_should_detect_path_traversal_in_filename(self):
|
||||
"""Test protection against directory traversal attacks"""
|
||||
dangerous_filenames = [
|
||||
"../../../etc/passwd",
|
||||
"..\\..\\windows\\system32\\config\\sam",
|
||||
"../../../../etc/shadow",
|
||||
"./../../../sensitive.txt",
|
||||
]
|
||||
|
||||
for filename in dangerous_filenames:
|
||||
# Any filename containing .. should be considered dangerous
|
||||
assert ".." in filename, f"Filename {filename} should be detected as path traversal"
|
||||
|
||||
def test_should_detect_null_byte_injection(self):
|
||||
"""Test protection against null byte injection"""
|
||||
dangerous_filenames = [
|
||||
"file.jpg\x00.php",
|
||||
"document.pdf\x00.exe",
|
||||
"image.png\x00.sh",
|
||||
]
|
||||
|
||||
for filename in dangerous_filenames:
|
||||
# Null bytes should be detected
|
||||
assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection"
|
||||
|
||||
def test_should_sanitize_special_characters(self):
|
||||
"""Test that special characters in filenames are handled safely"""
|
||||
# Characters that could be problematic in various contexts
|
||||
dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"]
|
||||
|
||||
for char in dangerous_chars:
|
||||
filename = f"file{char}name.txt"
|
||||
# These characters should be detected or sanitized
|
||||
assert any(c in filename for c in dangerous_chars)
|
||||
|
||||
# Test 3: Permission validation
|
||||
def test_should_validate_dataset_permissions(self):
|
||||
"""Test dataset upload permission logic"""
|
||||
|
||||
class MockUser:
|
||||
is_dataset_editor = False
|
||||
|
||||
user = MockUser()
|
||||
source = "datasets"
|
||||
|
||||
# Simulate the permission check in FileApi.post()
|
||||
if source == "datasets" and not user.is_dataset_editor:
|
||||
with pytest.raises(Forbidden):
|
||||
raise Forbidden()
|
||||
|
||||
def test_should_allow_general_upload_without_permission(self):
|
||||
"""Test general upload doesn't require dataset permission"""
|
||||
|
||||
class MockUser:
|
||||
is_dataset_editor = False
|
||||
|
||||
user = MockUser()
|
||||
source = None # General upload
|
||||
|
||||
# This should not raise an exception
|
||||
if source == "datasets" and not user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
# Test passes if no exception is raised
|
||||
|
||||
# Test 4: Service error handling
|
||||
@patch("services.file_service.FileService.upload_file")
|
||||
def test_should_handle_file_too_large_error(self, mock_upload):
|
||||
"""Test that service FileTooLargeError is properly converted"""
|
||||
mock_upload.side_effect = ServiceFileTooLargeError("File too large")
|
||||
|
||||
try:
|
||||
mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None)
|
||||
except ServiceFileTooLargeError as e:
|
||||
# Simulate the error conversion in FileApi.post()
|
||||
with pytest.raises(FileTooLargeError):
|
||||
raise FileTooLargeError(e.description)
|
||||
|
||||
@patch("services.file_service.FileService.upload_file")
|
||||
def test_should_handle_unsupported_file_type_error(self, mock_upload):
|
||||
"""Test that service UnsupportedFileTypeError is properly converted"""
|
||||
mock_upload.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
mock_upload(
|
||||
filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None
|
||||
)
|
||||
except ServiceUnsupportedFileTypeError:
|
||||
# Simulate the error conversion in FileApi.post()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# Test 5: File type security
|
||||
def test_should_identify_dangerous_file_extensions(self):
|
||||
"""Test detection of potentially dangerous file extensions"""
|
||||
dangerous_extensions = [
|
||||
".php",
|
||||
".PHP",
|
||||
".pHp", # PHP files (case variations)
|
||||
".exe",
|
||||
".EXE", # Executables
|
||||
".sh",
|
||||
".SH", # Shell scripts
|
||||
".bat",
|
||||
".BAT", # Batch files
|
||||
".cmd",
|
||||
".CMD", # Command files
|
||||
".ps1",
|
||||
".PS1", # PowerShell
|
||||
".jar",
|
||||
".JAR", # Java archives
|
||||
".vbs",
|
||||
".VBS", # VBScript
|
||||
]
|
||||
|
||||
safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"]
|
||||
|
||||
# Just verify our test data is correct
|
||||
for ext in dangerous_extensions:
|
||||
assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
|
||||
|
||||
for ext in safe_extensions:
|
||||
assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"]
|
||||
|
||||
def test_should_detect_double_extensions(self):
|
||||
"""Test detection of double extension attacks"""
|
||||
suspicious_filenames = [
|
||||
"image.jpg.php",
|
||||
"document.pdf.exe",
|
||||
"photo.png.sh",
|
||||
"file.txt.bat",
|
||||
]
|
||||
|
||||
for filename in suspicious_filenames:
|
||||
# Check that these have multiple extensions
|
||||
parts = filename.split(".")
|
||||
assert len(parts) > 2, f"Filename {filename} should have multiple extensions"
|
||||
|
||||
# Test 6: Configuration validation
|
||||
def test_upload_configuration_structure(self):
|
||||
"""Test that upload configuration has correct structure"""
|
||||
# Simulate the configuration returned by FileApi.get()
|
||||
config = {
|
||||
"file_size_limit": 15,
|
||||
"batch_count_limit": 5,
|
||||
"image_file_size_limit": 10,
|
||||
"video_file_size_limit": 500,
|
||||
"audio_file_size_limit": 50,
|
||||
"workflow_file_upload_limit": 10,
|
||||
}
|
||||
|
||||
# Verify all required fields are present
|
||||
required_fields = [
|
||||
"file_size_limit",
|
||||
"batch_count_limit",
|
||||
"image_file_size_limit",
|
||||
"video_file_size_limit",
|
||||
"audio_file_size_limit",
|
||||
"workflow_file_upload_limit",
|
||||
]
|
||||
|
||||
for field in required_fields:
|
||||
assert field in config, f"Missing required field: {field}"
|
||||
assert isinstance(config[field], int), f"Field {field} should be an integer"
|
||||
assert config[field] > 0, f"Field {field} should be positive"
|
||||
|
||||
# Test 7: Source parameter handling
|
||||
def test_source_parameter_normalization(self):
|
||||
"""Test that source parameter is properly normalized"""
|
||||
test_cases = [
|
||||
("datasets", "datasets"),
|
||||
("other", None),
|
||||
("", None),
|
||||
(None, None),
|
||||
]
|
||||
|
||||
for input_source, expected in test_cases:
|
||||
# Simulate the source normalization in FileApi.post()
|
||||
source = "datasets" if input_source == "datasets" else None
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
assert source == expected
|
||||
|
||||
# Test 8: Boundary conditions
|
||||
def test_should_handle_edge_case_file_sizes(self):
|
||||
"""Test handling of boundary file sizes"""
|
||||
test_cases = [
|
||||
(0, "Empty file"), # 0 bytes
|
||||
(1, "Single byte"), # 1 byte
|
||||
(15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB
|
||||
(15 * 1024 * 1024, "At limit"), # Exactly 15MB
|
||||
(15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB
|
||||
]
|
||||
|
||||
for size, description in test_cases:
|
||||
# Just verify our test data
|
||||
assert isinstance(size, int), f"{description}: Size should be integer"
|
||||
assert size >= 0, f"{description}: Size should be non-negative"
|
||||
|
||||
def test_should_handle_special_mime_types(self):
|
||||
"""Test handling of various MIME types"""
|
||||
mime_type_tests = [
|
||||
("application/octet-stream", "Generic binary"),
|
||||
("text/plain", "Plain text"),
|
||||
("image/jpeg", "JPEG image"),
|
||||
("application/pdf", "PDF document"),
|
||||
("", "Empty MIME type"),
|
||||
(None, "None MIME type"),
|
||||
]
|
||||
|
||||
for mime_type, description in mime_type_tests:
|
||||
# Verify test data structure
|
||||
if mime_type is not None:
|
||||
assert isinstance(mime_type, str), f"{description}: MIME type should be string or None"
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
|
||||
|
|
@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType:
|
|||
value is tested for the is_array_type method.
|
||||
"""
|
||||
# Arrange
|
||||
all_segment_types = set(SegmentType)
|
||||
expected_array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
|
|
@ -58,3 +57,27 @@ class TestSegmentTypeIsArrayType:
|
|||
for seg_type in enum_values:
|
||||
is_array = seg_type.is_array_type()
|
||||
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||
|
||||
|
||||
class TestSegmentTypeIsValidArrayValidation:
|
||||
"""
|
||||
Test SegmentType.is_valid with array types using different validation strategies.
|
||||
"""
|
||||
|
||||
def test_array_validation_all_success(self):
|
||||
value = ["hello", "world", "foo"]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
|
||||
def test_array_validation_all_fail(self):
|
||||
value = ["hello", 123, "world"]
|
||||
# Should return False, since 123 is not a string
|
||||
assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
|
||||
def test_array_validation_first(self):
|
||||
value = ["hello", 123, None]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
|
||||
|
||||
def test_array_validation_none(self):
|
||||
value = [1, 2, 3]
|
||||
# validation is None, skip
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
||||
|
|
|
|||
32
api/uv.lock
32
api/uv.lock
|
|
@ -1318,6 +1318,7 @@ dev = [
|
|||
{ name = "pytest-mock" },
|
||||
{ name = "ruff" },
|
||||
{ name = "scipy-stubs" },
|
||||
{ name = "testcontainers" },
|
||||
{ name = "types-aiofiles" },
|
||||
{ name = "types-beautifulsoup4" },
|
||||
{ name = "types-cachetools" },
|
||||
|
|
@ -1500,6 +1501,7 @@ dev = [
|
|||
{ name = "pytest-mock", specifier = "~=3.14.0" },
|
||||
{ name = "ruff", specifier = "~=0.12.3" },
|
||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||
{ name = "testcontainers", specifier = "~=4.10.0" },
|
||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||
{ name = "types-cachetools", specifier = "~=5.5.0" },
|
||||
|
|
@ -1600,6 +1602,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.16"
|
||||
|
|
@ -5468,6 +5484,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "testcontainers"
|
||||
version = "4.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "docker" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "urllib3" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tidb-vector"
|
||||
version = "0.0.9"
|
||||
|
|
|
|||
|
|
@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh
|
|||
|
||||
# Unit tests
|
||||
dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
# TestContainers tests
|
||||
dev/pytest/pytest_testcontainers.sh
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
pytest api/tests/test_containers_integration_tests
|
||||
|
|
@ -653,6 +653,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com
|
|||
TABLESTORE_INSTANCE_NAME=instance-name
|
||||
TABLESTORE_ACCESS_KEY_ID=xxx
|
||||
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
|
||||
|
||||
# ------------------------------
|
||||
# Knowledge Configuration
|
||||
|
|
|
|||
|
|
@ -312,6 +312,7 @@ x-shared-env: &shared-api-worker-env
|
|||
TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
|
||||
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
|
||||
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
|
||||
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
||||
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
||||
ETL_TYPE: ${ETL_TYPE:-dify}
|
||||
|
|
|
|||
|
|
@ -265,7 +265,6 @@ export default translation
|
|||
fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content)
|
||||
|
||||
const allEnKeys = await getKeysFromLanguage('en-US')
|
||||
const allZhKeys = await getKeysFromLanguage('zh-Hans')
|
||||
|
||||
// Test file filtering logic
|
||||
const targetFile = 'components'
|
||||
|
|
@ -563,4 +562,201 @@ export default translation
|
|||
expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys
|
||||
})
|
||||
})
|
||||
|
||||
describe('Auto-remove multiline key-value pairs', () => {
|
||||
// Helper function to simulate removeExtraKeysFromFile logic
|
||||
function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string {
|
||||
const lines = content.split('\n')
|
||||
const linesToRemove: number[] = []
|
||||
|
||||
for (const keyToRemove of keysToRemove) {
|
||||
let targetLineIndex = -1
|
||||
const linesToRemoveForKey: number[] = []
|
||||
|
||||
// Find the key line (simplified for single-level keys in test)
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i]
|
||||
const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`)
|
||||
if (keyPattern.test(line)) {
|
||||
targetLineIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if (targetLineIndex !== -1) {
|
||||
linesToRemoveForKey.push(targetLineIndex)
|
||||
|
||||
// Check if this is a multiline key-value pair
|
||||
const keyLine = lines[targetLineIndex]
|
||||
const trimmedKeyLine = keyLine.trim()
|
||||
|
||||
// If key line ends with ":" (not complete value), it's likely multiline
|
||||
if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) {
|
||||
// Find the value lines that belong to this key
|
||||
let currentLine = targetLineIndex + 1
|
||||
let foundValue = false
|
||||
|
||||
while (currentLine < lines.length) {
|
||||
const line = lines[currentLine]
|
||||
const trimmed = line.trim()
|
||||
|
||||
// Skip empty lines
|
||||
if (trimmed === '') {
|
||||
currentLine++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this line starts a new key (indicates end of current value)
|
||||
if (trimmed.match(/^\w+\s*:/))
|
||||
break
|
||||
|
||||
// Check if this line is part of the value
|
||||
if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) {
|
||||
linesToRemoveForKey.push(currentLine)
|
||||
foundValue = true
|
||||
|
||||
// Check if this line ends the value (ends with quote and comma/no comma)
|
||||
if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,')
|
||||
|| trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`'))
|
||||
&& !trimmed.startsWith('//'))
|
||||
break
|
||||
}
|
||||
else {
|
||||
break
|
||||
}
|
||||
|
||||
currentLine++
|
||||
}
|
||||
}
|
||||
|
||||
linesToRemove.push(...linesToRemoveForKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove duplicates and sort in reverse order
|
||||
const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a)
|
||||
|
||||
for (const lineIndex of uniqueLinesToRemove)
|
||||
lines.splice(lineIndex, 1)
|
||||
|
||||
return lines.join('\n')
|
||||
}
|
||||
|
||||
it('should remove single-line key-value pairs correctly', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'This should stay',
|
||||
removeThis: 'This should be removed',
|
||||
alsoKeep: 'This should also stay',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeThis'])
|
||||
|
||||
expect(result).toContain('keepThis: \'This should stay\'')
|
||||
expect(result).toContain('alsoKeep: \'This should also stay\'')
|
||||
expect(result).not.toContain('removeThis: \'This should be removed\'')
|
||||
})
|
||||
|
||||
it('should remove multiline key-value pairs completely', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'This should stay',
|
||||
removeMultiline:
|
||||
'This is a multiline value that should be removed completely',
|
||||
alsoKeep: 'This should also stay',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeMultiline'])
|
||||
|
||||
expect(result).toContain('keepThis: \'This should stay\'')
|
||||
expect(result).toContain('alsoKeep: \'This should also stay\'')
|
||||
expect(result).not.toContain('removeMultiline:')
|
||||
expect(result).not.toContain('This is a multiline value that should be removed completely')
|
||||
})
|
||||
|
||||
it('should handle mixed single-line and multiline removals', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'Keep this',
|
||||
removeSingle: 'Remove this single line',
|
||||
removeMultiline:
|
||||
'Remove this multiline value',
|
||||
anotherMultiline:
|
||||
'Another multiline that spans multiple lines',
|
||||
keepAnother: 'Keep this too',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline'])
|
||||
|
||||
expect(result).toContain('keepThis: \'Keep this\'')
|
||||
expect(result).toContain('keepAnother: \'Keep this too\'')
|
||||
expect(result).not.toContain('removeSingle:')
|
||||
expect(result).not.toContain('removeMultiline:')
|
||||
expect(result).not.toContain('anotherMultiline:')
|
||||
expect(result).not.toContain('Remove this single line')
|
||||
expect(result).not.toContain('Remove this multiline value')
|
||||
expect(result).not.toContain('Another multiline that spans multiple lines')
|
||||
})
|
||||
|
||||
it('should properly detect multiline vs single-line patterns', () => {
|
||||
const multilineContent = `const translation = {
|
||||
singleLine: 'This is single line',
|
||||
multilineKey:
|
||||
'This is multiline',
|
||||
keyWithColon: 'Value with: colon inside',
|
||||
objectKey: {
|
||||
nested: 'value'
|
||||
},
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
// Test that single line with colon in value is not treated as multiline
|
||||
const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon'])
|
||||
expect(result1).not.toContain('keyWithColon:')
|
||||
expect(result1).not.toContain('Value with: colon inside')
|
||||
|
||||
// Test that true multiline is handled correctly
|
||||
const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey'])
|
||||
expect(result2).not.toContain('multilineKey:')
|
||||
expect(result2).not.toContain('This is multiline')
|
||||
|
||||
// Test that object key removal works (note: this is a simplified test)
|
||||
// In real scenario, object removal would be more complex
|
||||
const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey'])
|
||||
expect(result3).not.toContain('objectKey: {')
|
||||
// Note: Our simplified test function doesn't handle nested object removal perfectly
|
||||
// This is acceptable as it's testing the main multiline string removal functionality
|
||||
})
|
||||
|
||||
it('should handle real-world Polish translation structure', () => {
|
||||
const polishContent = `const translation = {
|
||||
createApp: 'UTWÓRZ APLIKACJĘ',
|
||||
newApp: {
|
||||
captionAppType: 'Jaki typ aplikacji chcesz stworzyć?',
|
||||
chatbotDescription:
|
||||
'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.',
|
||||
agentDescription:
|
||||
'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.',
|
||||
basic: 'Podstawowy',
|
||||
},
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription'])
|
||||
|
||||
expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'')
|
||||
expect(result).toContain('basic: \'Podstawowy\'')
|
||||
expect(result).not.toContain('captionAppType:')
|
||||
expect(result).not.toContain('chatbotDescription:')
|
||||
expect(result).not.toContain('agentDescription:')
|
||||
expect(result).not.toContain('Jaki typ aplikacji')
|
||||
expect(result).not.toContain('Zbuduj aplikację opartą na czacie')
|
||||
expect(result).not.toContain('Zbuduj inteligentnego agenta')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,305 @@
|
|||
/**
|
||||
* Document Detail Navigation Fix Verification Test
|
||||
*
|
||||
* This test specifically validates that the backToPrev function in the document detail
|
||||
* component correctly preserves pagination and filter states.
|
||||
*/
|
||||
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document'
|
||||
|
||||
// Mock Next.js router
|
||||
const mockPush = jest.fn()
|
||||
jest.mock('next/navigation', () => ({
|
||||
useRouter: jest.fn(() => ({
|
||||
push: mockPush,
|
||||
})),
|
||||
}))
|
||||
|
||||
// Mock the document service hooks
|
||||
jest.mock('@/service/knowledge/use-document', () => ({
|
||||
useDocumentDetail: jest.fn(),
|
||||
useDocumentMetadata: jest.fn(),
|
||||
useInvalidDocumentList: jest.fn(() => jest.fn()),
|
||||
}))
|
||||
|
||||
// Mock other dependencies
|
||||
jest.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContext: jest.fn(() => [null]),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/use-base', () => ({
|
||||
useInvalid: jest.fn(() => jest.fn()),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/knowledge/use-segment', () => ({
|
||||
useSegmentListKey: jest.fn(),
|
||||
useChildSegmentListKey: jest.fn(),
|
||||
}))
|
||||
|
||||
// Create a minimal version of the DocumentDetail component that includes our fix
|
||||
const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => {
|
||||
const router = useRouter()
|
||||
|
||||
// This is the FIXED implementation from detail/index.tsx
|
||||
const backToPrev = () => {
|
||||
// Preserve pagination and filter states when navigating back
|
||||
const searchParams = new URLSearchParams(window.location.search)
|
||||
const queryString = searchParams.toString()
|
||||
const separator = queryString ? '?' : ''
|
||||
const backPath = `/datasets/${datasetId}/documents${separator}${queryString}`
|
||||
router.push(backPath)
|
||||
}
|
||||
|
||||
return (
|
||||
<div data-testid="document-detail-fixed">
|
||||
<button data-testid="back-button-fixed" onClick={backToPrev}>
|
||||
Back to Documents
|
||||
</button>
|
||||
<div data-testid="document-info">
|
||||
Dataset: {datasetId}, Document: {documentId}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('Document Detail Navigation Fix Verification', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Mock successful API responses
|
||||
;(useDocumentDetail as jest.Mock).mockReturnValue({
|
||||
data: {
|
||||
id: 'doc-123',
|
||||
name: 'Test Document',
|
||||
display_status: 'available',
|
||||
enabled: true,
|
||||
archived: false,
|
||||
},
|
||||
error: null,
|
||||
})
|
||||
|
||||
;(useDocumentMetadata as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
error: null,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Query Parameter Preservation', () => {
|
||||
test('preserves pagination state (page 3, limit 25)', () => {
|
||||
// Simulate user coming from page 3 with 25 items per page
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=3&limit=25',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
// User clicks back button
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should preserve the pagination state
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25')
|
||||
|
||||
console.log('✅ Pagination state preserved: page=3&limit=25')
|
||||
})
|
||||
|
||||
test('preserves search keyword and filters', () => {
|
||||
// Simulate user with search and filters applied
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=2&limit=10&keyword=API%20documentation&status=active',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should preserve all query parameters
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active')
|
||||
|
||||
console.log('✅ Search and filters preserved')
|
||||
})
|
||||
|
||||
test('handles complex query parameters with special characters', () => {
|
||||
// Test with complex query string including encoded characters
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// URLSearchParams will normalize the encoding, but preserve all parameters
|
||||
const expectedCall = mockPush.mock.calls[0][0]
|
||||
expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/)
|
||||
expect(expectedCall).toMatch(/page=1/)
|
||||
expect(expectedCall).toMatch(/limit=50/)
|
||||
expect(expectedCall).toMatch(/keyword=test/)
|
||||
expect(expectedCall).toMatch(/sort=name/)
|
||||
expect(expectedCall).toMatch(/order=desc/)
|
||||
|
||||
console.log('✅ Complex query parameters handled:', expectedCall)
|
||||
})
|
||||
|
||||
test('handles empty query parameters gracefully', () => {
|
||||
// No query parameters in URL
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should navigate to clean documents URL
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents')
|
||||
|
||||
console.log('✅ Empty parameters handled gracefully')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Different Dataset IDs', () => {
|
||||
test('works with different dataset identifiers', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=5&limit=10',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
// Test with different dataset ID format
|
||||
render(<DocumentDetailWithFix datasetId="ds-prod-2024-001" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10')
|
||||
|
||||
console.log('✅ Works with different dataset ID formats')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Real User Scenarios', () => {
|
||||
test('scenario: user searches, goes to page 3, views document, clicks back', () => {
|
||||
// User searched for "API" and navigated to page 3
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?keyword=API&page=3&limit=10',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="main-dataset" documentId="api-doc-123" />)
|
||||
|
||||
// User decides to go back to continue browsing
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should return to page 3 of API search results
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10')
|
||||
|
||||
console.log('✅ Real user scenario: search + pagination preserved')
|
||||
})
|
||||
|
||||
test('scenario: user applies multiple filters, goes to document, returns', () => {
|
||||
// User has applied multiple filters and is on page 2
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="filtered-dataset" documentId="filtered-doc" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// All filters should be preserved
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc')
|
||||
|
||||
console.log('✅ Complex filtering scenario preserved')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling and Edge Cases', () => {
|
||||
test('handles malformed query parameters gracefully', () => {
|
||||
// Test with potentially problematic query string
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=invalid&limit=&keyword=test&=emptykey&malformed',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
// Should not throw errors
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
}).not.toThrow()
|
||||
|
||||
// Should still attempt navigation (URLSearchParams will clean up the parameters)
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
const navigationPath = mockPush.mock.calls[0][0]
|
||||
expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/)
|
||||
|
||||
console.log('✅ Malformed parameters handled gracefully:', navigationPath)
|
||||
})
|
||||
|
||||
test('handles very long query strings', () => {
|
||||
// Test with a very long query string
|
||||
const longKeyword = 'a'.repeat(1000)
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: `?page=1&keyword=${longKeyword}`,
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
}).not.toThrow()
|
||||
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
|
||||
console.log('✅ Long query strings handled')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Performance Verification', () => {
|
||||
test('navigation function executes quickly', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=1&limit=10&keyword=test',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
const startTime = performance.now()
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
const endTime = performance.now()
|
||||
|
||||
const executionTime = endTime - startTime
|
||||
|
||||
// Should execute in less than 10ms
|
||||
expect(executionTime).toBeLessThan(10)
|
||||
|
||||
console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Document List Sorting Tests
|
||||
*/
|
||||
|
||||
describe('Document List Sorting', () => {
|
||||
const mockDocuments = [
|
||||
{ id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 },
|
||||
{ id: '2', name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 },
|
||||
{ id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 },
|
||||
]
|
||||
|
||||
const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => {
|
||||
return [...docs].sort((a, b) => {
|
||||
let aValue: any
|
||||
let bValue: any
|
||||
|
||||
switch (field) {
|
||||
case 'name':
|
||||
aValue = a.name?.toLowerCase() || ''
|
||||
bValue = b.name?.toLowerCase() || ''
|
||||
break
|
||||
case 'word_count':
|
||||
aValue = a.word_count || 0
|
||||
bValue = b.word_count || 0
|
||||
break
|
||||
case 'hit_count':
|
||||
aValue = a.hit_count || 0
|
||||
bValue = b.hit_count || 0
|
||||
break
|
||||
case 'created_at':
|
||||
aValue = a.created_at
|
||||
bValue = b.created_at
|
||||
break
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
if (field === 'name') {
|
||||
const result = aValue.localeCompare(bValue)
|
||||
return order === 'asc' ? result : -result
|
||||
}
|
||||
else {
|
||||
const result = aValue - bValue
|
||||
return order === 'asc' ? result : -result
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
test('sorts by name descending (default for UI consistency)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'name', 'desc')
|
||||
expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt'])
|
||||
})
|
||||
|
||||
test('sorts by name ascending (after toggle)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'name', 'asc')
|
||||
expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx'])
|
||||
})
|
||||
|
||||
test('sorts by word_count descending', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'word_count', 'desc')
|
||||
expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200])
|
||||
})
|
||||
|
||||
test('sorts by hit_count descending', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc')
|
||||
expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5])
|
||||
})
|
||||
|
||||
test('sorts by created_at descending (newest first)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'created_at', 'desc')
|
||||
expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400])
|
||||
})
|
||||
|
||||
test('handles empty values correctly', () => {
|
||||
const docsWithEmpty = [
|
||||
{ id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 },
|
||||
{ id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 },
|
||||
]
|
||||
|
||||
const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc')
|
||||
expect(sorted.map(doc => doc.word_count)).toEqual([100, 0])
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
/**
|
||||
* Navigation Utilities Test
|
||||
*
|
||||
* Tests for the navigation utility functions to ensure they handle
|
||||
* query parameter preservation correctly across different scenarios.
|
||||
*/
|
||||
|
||||
import {
|
||||
createBackNavigation,
|
||||
createNavigationPath,
|
||||
createNavigationPathWithParams,
|
||||
datasetNavigation,
|
||||
extractQueryParams,
|
||||
mergeQueryParams,
|
||||
} from '@/utils/navigation'
|
||||
|
||||
// Mock router for testing
|
||||
const mockPush = jest.fn()
|
||||
const mockRouter = { push: mockPush }
|
||||
|
||||
describe('Navigation Utilities', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('createNavigationPath', () => {
|
||||
test('preserves query parameters by default', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test')
|
||||
})
|
||||
|
||||
test('returns clean path when preserveParams is false', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents', false)
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
})
|
||||
|
||||
test('handles empty query parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
// Mock window.location to throw an error
|
||||
Object.defineProperty(window, 'location', {
|
||||
get: () => {
|
||||
throw new Error('Location access denied')
|
||||
},
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createBackNavigation', () => {
|
||||
test('creates function that navigates with preserved params', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25')
|
||||
})
|
||||
|
||||
test('creates function that navigates without params when specified', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false)
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents')
|
||||
})
|
||||
})
|
||||
|
||||
describe('extractQueryParams', () => {
|
||||
test('extracts specified parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test&other=value' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const params = extractQueryParams(['page', 'limit', 'keyword'])
|
||||
expect(params).toEqual({
|
||||
page: '3',
|
||||
limit: '10',
|
||||
keyword: 'test',
|
||||
})
|
||||
})
|
||||
|
||||
test('handles missing parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const params = extractQueryParams(['page', 'limit', 'missing'])
|
||||
expect(params).toEqual({
|
||||
page: '3',
|
||||
})
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
get: () => {
|
||||
throw new Error('Location access denied')
|
||||
},
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const params = extractQueryParams(['page', 'limit'])
|
||||
|
||||
expect(params).toEqual({})
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createNavigationPathWithParams', () => {
|
||||
test('creates path with specified parameters', () => {
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', {
|
||||
page: 1,
|
||||
limit: 25,
|
||||
keyword: 'search term',
|
||||
})
|
||||
|
||||
expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term')
|
||||
})
|
||||
|
||||
test('filters out empty values', () => {
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', {
|
||||
page: 1,
|
||||
limit: '',
|
||||
keyword: 'test',
|
||||
empty: null,
|
||||
undefined,
|
||||
})
|
||||
|
||||
expect(path).toBe('/datasets/123/documents?page=1&keyword=test')
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
// Mock URLSearchParams to throw an error
|
||||
const originalURLSearchParams = globalThis.URLSearchParams
|
||||
globalThis.URLSearchParams = jest.fn(() => {
|
||||
throw new Error('URLSearchParams error')
|
||||
}) as any
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 })
|
||||
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
globalThis.URLSearchParams = originalURLSearchParams
|
||||
})
|
||||
})
|
||||
|
||||
describe('mergeQueryParams', () => {
|
||||
test('merges new params with existing ones', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: 'test', page: '1' })
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toContain('page=1') // overridden
|
||||
expect(result).toContain('limit=10') // preserved
|
||||
expect(result).toContain('keyword=test') // added
|
||||
})
|
||||
|
||||
test('removes parameters when value is null', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: null, filter: 'active' })
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toContain('page=3')
|
||||
expect(result).toContain('limit=10')
|
||||
expect(result).not.toContain('keyword')
|
||||
expect(result).toContain('filter=active')
|
||||
})
|
||||
|
||||
test('creates fresh params when preserveExisting is false', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: 'test' }, false)
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toBe('keyword=test')
|
||||
})
|
||||
})
|
||||
|
||||
describe('datasetNavigation', () => {
|
||||
test('backToDocuments creates correct navigation function', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25')
|
||||
})
|
||||
|
||||
test('toDocumentDetail creates correct navigation function', () => {
|
||||
const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456')
|
||||
detailFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456')
|
||||
})
|
||||
|
||||
test('toDocumentSettings creates correct navigation function', () => {
|
||||
const settingsFn = datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456')
|
||||
settingsFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Real-world Integration Scenarios', () => {
|
||||
test('complete user workflow: list -> detail -> back', () => {
|
||||
// User starts on page 3 with search
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&keyword=API&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
// Create back navigation function (as would be done in detail component)
|
||||
const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset')
|
||||
|
||||
// User clicks back
|
||||
backToDocuments()
|
||||
|
||||
// Should return to exact same list state
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25')
|
||||
})
|
||||
|
||||
test('user applies filters then views document', () => {
|
||||
// Complex filter state
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,396 @@
|
|||
/**
|
||||
* Unified Tags Editing - Pure Logic Tests
|
||||
*
|
||||
* This test file validates the core business logic and state management
|
||||
* behaviors introduced in the recent 7 commits without requiring complex mocks.
|
||||
*/
|
||||
|
||||
describe('Unified Tags Editing - Pure Logic Tests', () => {
|
||||
describe('Tag State Management Logic', () => {
|
||||
it('should detect when tag values have changed', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const newSelectedTagIDs = ['tag1', 'tag3']
|
||||
|
||||
// This is the valueNotChanged logic from TagSelector component
|
||||
const valueNotChanged
|
||||
= currentValue.length === newSelectedTagIDs.length
|
||||
&& currentValue.every(v => newSelectedTagIDs.includes(v))
|
||||
&& newSelectedTagIDs.every(v => currentValue.includes(v))
|
||||
|
||||
expect(valueNotChanged).toBe(false)
|
||||
})
|
||||
|
||||
it('should correctly identify unchanged tag values', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order
|
||||
|
||||
const valueNotChanged
|
||||
= currentValue.length === newSelectedTagIDs.length
|
||||
&& currentValue.every(v => newSelectedTagIDs.includes(v))
|
||||
&& newSelectedTagIDs.every(v => currentValue.includes(v))
|
||||
|
||||
expect(valueNotChanged).toBe(true)
|
||||
})
|
||||
|
||||
it('should calculate correct tag operations for binding/unbinding', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const selectedTagIDs = ['tag2', 'tag3']
|
||||
|
||||
// This is the handleValueChange logic from TagSelector
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual(['tag3'])
|
||||
expect(removeTagIDs).toEqual(['tag1'])
|
||||
})
|
||||
|
||||
it('should handle empty tag arrays correctly', () => {
|
||||
const currentValue: string[] = []
|
||||
const selectedTagIDs = ['tag1']
|
||||
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual(['tag1'])
|
||||
expect(removeTagIDs).toEqual([])
|
||||
expect(currentValue.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
|
||||
it('should handle removing all tags', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const selectedTagIDs: string[] = []
|
||||
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual([])
|
||||
expect(removeTagIDs).toEqual(['tag1', 'tag2'])
|
||||
expect(selectedTagIDs.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
})
|
||||
|
||||
describe('Fallback Logic (from layout-main.tsx)', () => {
|
||||
it('should trigger fallback when tags are missing or empty', () => {
|
||||
const appDetailWithoutTags = { tags: [] }
|
||||
const appDetailWithTags = { tags: [{ id: 'tag1' }] }
|
||||
const appDetailWithUndefinedTags = { tags: undefined as any }
|
||||
|
||||
// This simulates the condition in layout-main.tsx
|
||||
const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
|
||||
const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
|
||||
const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0
|
||||
|
||||
expect(shouldFallback1).toBe(true) // Empty array should trigger fallback
|
||||
expect(shouldFallback2).toBe(false) // Has tags, no fallback needed
|
||||
expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback
|
||||
})
|
||||
|
||||
it('should preserve tags when fallback succeeds', () => {
|
||||
const originalAppDetail = { tags: [] as any[] }
|
||||
const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
|
||||
|
||||
// This simulates the successful fallback in layout-main.tsx
|
||||
if (fallbackResult?.tags)
|
||||
originalAppDetail.tags = fallbackResult.tags
|
||||
|
||||
expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
|
||||
expect(originalAppDetail.tags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should continue with empty tags when fallback fails', () => {
|
||||
const originalAppDetail: { tags: any[] } = { tags: [] }
|
||||
const fallbackResult: { tags?: any[] } | null = null
|
||||
|
||||
// This simulates fallback failure in layout-main.tsx
|
||||
if (fallbackResult?.tags)
|
||||
originalAppDetail.tags = fallbackResult.tags
|
||||
|
||||
expect(originalAppDetail.tags).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('TagSelector Auto-initialization Logic', () => {
|
||||
it('should trigger getTagList when tagList is empty', () => {
|
||||
const tagList: any[] = []
|
||||
let getTagListCalled = false
|
||||
const getTagList = () => {
|
||||
getTagListCalled = true
|
||||
}
|
||||
|
||||
// This simulates the useEffect in TagSelector
|
||||
if (tagList.length === 0)
|
||||
getTagList()
|
||||
|
||||
expect(getTagListCalled).toBe(true)
|
||||
})
|
||||
|
||||
it('should not trigger getTagList when tagList has items', () => {
|
||||
const tagList = [{ id: 'tag1', name: 'existing-tag' }]
|
||||
let getTagListCalled = false
|
||||
const getTagList = () => {
|
||||
getTagListCalled = true
|
||||
}
|
||||
|
||||
// This simulates the useEffect in TagSelector
|
||||
if (tagList.length === 0)
|
||||
getTagList()
|
||||
|
||||
expect(getTagListCalled).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('State Initialization Patterns', () => {
|
||||
it('should maintain AppCard tag state pattern', () => {
|
||||
const app = { tags: [{ id: 'tag1', name: 'test' }] }
|
||||
|
||||
// Original AppCard pattern: useState(app.tags)
|
||||
const initialTags = app.tags
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(1)
|
||||
expect(initialTags).toBe(app.tags) // Reference equality for AppCard
|
||||
})
|
||||
|
||||
it('should maintain AppInfo tag state pattern', () => {
|
||||
const appDetail = { tags: [{ id: 'tag1', name: 'test' }] }
|
||||
|
||||
// New AppInfo pattern: useState(appDetail?.tags || [])
|
||||
const initialTags = appDetail?.tags || []
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should handle undefined appDetail gracefully in AppInfo', () => {
|
||||
const appDetail = undefined
|
||||
|
||||
// AppInfo pattern with undefined appDetail
|
||||
const initialTags = (appDetail as any)?.tags || []
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('CSS Class and Layout Logic', () => {
|
||||
it('should apply correct minimum width condition', () => {
|
||||
const minWidth = 'true'
|
||||
|
||||
// This tests the minWidth logic in TagSelector
|
||||
const shouldApplyMinWidth = minWidth && '!min-w-80'
|
||||
expect(shouldApplyMinWidth).toBe('!min-w-80')
|
||||
})
|
||||
|
||||
it('should not apply minimum width when not specified', () => {
|
||||
const minWidth = undefined
|
||||
|
||||
const shouldApplyMinWidth = minWidth && '!min-w-80'
|
||||
expect(shouldApplyMinWidth).toBeFalsy()
|
||||
})
|
||||
|
||||
it('should handle overflow layout classes correctly', () => {
|
||||
// This tests the layout pattern from AppCard and new AppInfo
|
||||
const overflowLayoutClasses = {
|
||||
container: 'flex w-0 grow items-center',
|
||||
inner: 'w-full',
|
||||
truncate: 'truncate',
|
||||
}
|
||||
|
||||
expect(overflowLayoutClasses.container).toContain('w-0 grow')
|
||||
expect(overflowLayoutClasses.inner).toContain('w-full')
|
||||
expect(overflowLayoutClasses.truncate).toBe('truncate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('fetchAppWithTags Service Logic', () => {
|
||||
it('should correctly find app by ID from app list', () => {
|
||||
const appList = [
|
||||
{ id: 'app1', name: 'App 1', tags: [] },
|
||||
{ id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] },
|
||||
{ id: 'app3', name: 'App 3', tags: [] },
|
||||
]
|
||||
const targetAppId = 'test-app-id'
|
||||
|
||||
// This simulates the logic in fetchAppWithTags
|
||||
const foundApp = appList.find(app => app.id === targetAppId)
|
||||
|
||||
expect(foundApp).toBeDefined()
|
||||
expect(foundApp?.id).toBe('test-app-id')
|
||||
expect(foundApp?.tags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should return null when app not found', () => {
|
||||
const appList = [
|
||||
{ id: 'app1', name: 'App 1' },
|
||||
{ id: 'app2', name: 'App 2' },
|
||||
]
|
||||
const targetAppId = 'nonexistent-app'
|
||||
|
||||
const foundApp = appList.find(app => app.id === targetAppId) || null
|
||||
|
||||
expect(foundApp).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle empty app list', () => {
|
||||
const appList: any[] = []
|
||||
const targetAppId = 'any-app'
|
||||
|
||||
const foundApp = appList.find(app => app.id === targetAppId) || null
|
||||
|
||||
expect(foundApp).toBeNull()
|
||||
expect(appList.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
})
|
||||
|
||||
describe('Data Structure Validation', () => {
|
||||
it('should maintain consistent tag data structure', () => {
|
||||
const tag = {
|
||||
id: 'tag1',
|
||||
name: 'test-tag',
|
||||
type: 'app',
|
||||
binding_count: 1,
|
||||
}
|
||||
|
||||
expect(tag).toHaveProperty('id')
|
||||
expect(tag).toHaveProperty('name')
|
||||
expect(tag).toHaveProperty('type')
|
||||
expect(tag).toHaveProperty('binding_count')
|
||||
expect(tag.type).toBe('app')
|
||||
expect(typeof tag.binding_count).toBe('number')
|
||||
})
|
||||
|
||||
it('should handle tag arrays correctly', () => {
|
||||
const tags = [
|
||||
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
|
||||
{ id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 },
|
||||
]
|
||||
|
||||
expect(Array.isArray(tags)).toBe(true)
|
||||
expect(tags.length).toBe(2)
|
||||
expect(tags.every(tag => tag.type === 'app')).toBe(true)
|
||||
})
|
||||
|
||||
it('should validate app data structure with tags', () => {
|
||||
const app = {
|
||||
id: 'test-app',
|
||||
name: 'Test App',
|
||||
tags: [
|
||||
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
|
||||
],
|
||||
}
|
||||
|
||||
expect(app).toHaveProperty('id')
|
||||
expect(app).toHaveProperty('name')
|
||||
expect(app).toHaveProperty('tags')
|
||||
expect(Array.isArray(app.tags)).toBe(true)
|
||||
expect(app.tags.length).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Performance and Edge Cases', () => {
|
||||
it('should handle large tag arrays efficiently', () => {
|
||||
const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`)
|
||||
const selectedTags = ['tag1', 'tag50', 'tag99']
|
||||
|
||||
// Performance test: filtering should be efficient
|
||||
const startTime = Date.now()
|
||||
const addTags = selectedTags.filter(tag => !largeTags.includes(tag))
|
||||
const removeTags = largeTags.filter(tag => !selectedTags.includes(tag))
|
||||
const endTime = Date.now()
|
||||
|
||||
expect(endTime - startTime).toBeLessThan(10) // Should be very fast
|
||||
expect(addTags.length).toBe(0) // All selected tags exist
|
||||
expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove
|
||||
})
|
||||
|
||||
it('should handle malformed tag data gracefully', () => {
|
||||
const mixedData = [
|
||||
{ id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 },
|
||||
{ id: 'invalid1' }, // Missing required properties
|
||||
null,
|
||||
undefined,
|
||||
{ id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 },
|
||||
]
|
||||
|
||||
// Filter out invalid entries
|
||||
const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } =>
|
||||
tag != null
|
||||
&& typeof tag === 'object'
|
||||
&& 'id' in tag
|
||||
&& 'name' in tag
|
||||
&& 'type' in tag
|
||||
&& 'binding_count' in tag
|
||||
&& typeof tag.binding_count === 'number',
|
||||
)
|
||||
|
||||
expect(validTags.length).toBe(2)
|
||||
expect(validTags.every(tag => tag.id && tag.name)).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle concurrent tag operations correctly', () => {
|
||||
const operations = [
|
||||
{ type: 'add', tagIds: ['tag1', 'tag2'] },
|
||||
{ type: 'remove', tagIds: ['tag3'] },
|
||||
{ type: 'add', tagIds: ['tag4'] },
|
||||
]
|
||||
|
||||
// Simulate processing operations
|
||||
const results = operations.map(op => ({
|
||||
...op,
|
||||
processed: true,
|
||||
timestamp: Date.now(),
|
||||
}))
|
||||
|
||||
expect(results.length).toBe(3)
|
||||
expect(results.every(result => result.processed)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility Verification', () => {
|
||||
it('should not break existing AppCard behavior', () => {
|
||||
// Verify AppCard continues to work with original patterns
|
||||
const originalAppCardLogic = {
|
||||
initializeTags: (app: any) => app.tags,
|
||||
updateTags: (_currentTags: any[], newTags: any[]) => newTags,
|
||||
shouldRefresh: true,
|
||||
}
|
||||
|
||||
const app = { tags: [{ id: 'tag1', name: 'original' }] }
|
||||
const initializedTags = originalAppCardLogic.initializeTags(app)
|
||||
|
||||
expect(initializedTags).toBe(app.tags)
|
||||
expect(originalAppCardLogic.shouldRefresh).toBe(true)
|
||||
})
|
||||
|
||||
it('should ensure AppInfo follows AppCard patterns', () => {
|
||||
// Verify AppInfo uses compatible state management
|
||||
const appCardPattern = (app: any) => app.tags
|
||||
const appInfoPattern = (appDetail: any) => appDetail?.tags || []
|
||||
|
||||
const appWithTags = { tags: [{ id: 'tag1' }] }
|
||||
const appWithoutTags = { tags: [] }
|
||||
const undefinedApp = undefined
|
||||
|
||||
expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags))
|
||||
expect(appInfoPattern(appWithoutTags)).toEqual([])
|
||||
expect(appInfoPattern(undefinedApp)).toEqual([])
|
||||
})
|
||||
|
||||
it('should maintain consistent API parameters', () => {
|
||||
// Verify service layer maintains expected parameters
|
||||
const fetchAppListParams = {
|
||||
url: '/apps',
|
||||
params: { page: 1, limit: 100 },
|
||||
}
|
||||
|
||||
const tagApiParams = {
|
||||
bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }),
|
||||
unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }),
|
||||
}
|
||||
|
||||
expect(fetchAppListParams.url).toBe('/apps')
|
||||
expect(fetchAppListParams.params.limit).toBe(100)
|
||||
|
||||
const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app')
|
||||
expect(bindResult.tagIDs).toEqual(['tag1'])
|
||||
expect(bindResult.type).toBe('app')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
/**
|
||||
* XSS Fix Verification Test
|
||||
*
|
||||
* This test verifies that the XSS vulnerability in check-code pages has been
|
||||
* properly fixed by replacing dangerouslySetInnerHTML with safe React rendering.
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { cleanup, render } from '@testing-library/react'
|
||||
import '@testing-library/jest-dom'
|
||||
|
||||
// Mock i18next with the new safe translation structure
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
if (key === 'login.checkCode.tipsPrefix')
|
||||
return 'We send a verification code to '
|
||||
|
||||
return key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock Next.js useSearchParams
|
||||
jest.mock('next/navigation', () => ({
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => {
|
||||
if (key === 'email')
|
||||
return 'test@example.com<script>alert("XSS")</script>'
|
||||
return null
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// Fixed CheckCode component implementation (current secure version)
|
||||
const SecureCheckCodeComponent = ({ email }: { email: string }) => {
|
||||
const { t } = require('react-i18next').useTranslation()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h1>Check Code</h1>
|
||||
<p>
|
||||
<span>
|
||||
{t('login.checkCode.tipsPrefix')}
|
||||
<strong>{email}</strong>
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Vulnerable implementation for comparison (what we fixed)
|
||||
const VulnerableCheckCodeComponent = ({ email }: { email: string }) => {
|
||||
const mockTranslation = (key: string, params?: any) => {
|
||||
if (key === 'login.checkCode.tips' && params?.email)
|
||||
return `We send a verification code to <strong>${params.email}</strong>`
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h1>Check Code</h1>
|
||||
<p>
|
||||
<span dangerouslySetInnerHTML={{ __html: mockTranslation('login.checkCode.tips', { email }) }}></span>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('XSS Fix Verification - Check Code Pages Security', () => {
|
||||
afterEach(() => {
|
||||
cleanup()
|
||||
})
|
||||
|
||||
const maliciousEmail = 'test@example.com<script>alert("XSS")</script>'
|
||||
|
||||
it('should securely render email with HTML characters as text (FIXED VERSION)', () => {
|
||||
console.log('\n🔒 Security Fix Verification Report')
|
||||
console.log('===================================')
|
||||
|
||||
const { container } = render(<SecureCheckCodeComponent email={maliciousEmail} />)
|
||||
|
||||
const spanElement = container.querySelector('span')
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log('\n✅ Fixed Implementation Results:')
|
||||
console.log('- Email rendered in strong tag:', strongElement?.textContent)
|
||||
console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('<script>'))
|
||||
console.log('- Script elements created:', scriptElements.length)
|
||||
console.log('- Full text content:', spanElement?.textContent)
|
||||
|
||||
// Verify secure behavior
|
||||
expect(strongElement?.textContent).toBe(maliciousEmail) // Email rendered as text
|
||||
expect(strongElement?.textContent).toContain('<script>') // HTML visible as text
|
||||
expect(scriptElements).toHaveLength(0) // No script elements created
|
||||
expect(spanElement?.textContent).toBe(`We send a verification code to ${maliciousEmail}`)
|
||||
|
||||
console.log('\n🎯 Security Status: SECURE - HTML automatically escaped by React')
|
||||
})
|
||||
|
||||
it('should demonstrate the vulnerability that was fixed (VULNERABLE VERSION)', () => {
|
||||
const { container } = render(<VulnerableCheckCodeComponent email={maliciousEmail} />)
|
||||
|
||||
const spanElement = container.querySelector('span')
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log('\n⚠️ Previous Vulnerable Implementation:')
|
||||
console.log('- HTML content:', spanElement?.innerHTML)
|
||||
console.log('- Strong element text:', strongElement?.textContent)
|
||||
console.log('- Script elements created:', scriptElements.length)
|
||||
console.log('- Script content:', scriptElements[0]?.textContent)
|
||||
|
||||
// Verify vulnerability exists in old implementation
|
||||
expect(scriptElements).toHaveLength(1) // Script element was created
|
||||
expect(scriptElements[0]?.textContent).toBe('alert("XSS")') // Contains malicious code
|
||||
expect(spanElement?.innerHTML).toContain('<script>') // Raw HTML in DOM
|
||||
|
||||
console.log('\n❌ Security Status: VULNERABLE - dangerouslySetInnerHTML creates script elements')
|
||||
})
|
||||
|
||||
it('should verify all affected components use the secure pattern', () => {
|
||||
console.log('\n📋 Component Security Audit')
|
||||
console.log('============================')
|
||||
|
||||
// Test multiple malicious inputs
|
||||
const testCases = [
|
||||
'user@test.com<img src=x onerror=alert(1)>',
|
||||
'test@evil.com<div onclick="alert(2)">click</div>',
|
||||
'admin@site.com<script>document.cookie="stolen"</script>',
|
||||
'normal@email.com',
|
||||
]
|
||||
|
||||
testCases.forEach((testEmail, index) => {
|
||||
const { container } = render(<SecureCheckCodeComponent email={testEmail} />)
|
||||
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
const imgElements = container.querySelectorAll('img')
|
||||
const divElements = container.querySelectorAll('div:not([data-testid])')
|
||||
|
||||
console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`)
|
||||
console.log(` - Script elements: ${scriptElements.length}`)
|
||||
console.log(` - Img elements: ${imgElements.length}`)
|
||||
console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div
|
||||
console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`)
|
||||
|
||||
// All should be safe
|
||||
expect(scriptElements).toHaveLength(0)
|
||||
expect(imgElements).toHaveLength(0)
|
||||
expect(strongElement?.textContent).toBe(testEmail)
|
||||
})
|
||||
|
||||
console.log('\n✅ All test cases passed - secure rendering confirmed')
|
||||
})
|
||||
|
||||
it('should validate the translation structure is secure', () => {
|
||||
console.log('\n🔍 Translation Security Analysis')
|
||||
console.log('=================================')
|
||||
|
||||
const { t } = require('react-i18next').useTranslation()
|
||||
const prefix = t('login.checkCode.tipsPrefix')
|
||||
|
||||
console.log('- Translation key used: login.checkCode.tipsPrefix')
|
||||
console.log('- Translation value:', prefix)
|
||||
console.log('- Contains HTML tags:', prefix.includes('<'))
|
||||
console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>'))
|
||||
|
||||
// Verify translation is plain text
|
||||
expect(prefix).toBe('We send a verification code to ')
|
||||
expect(prefix).not.toContain('<')
|
||||
expect(prefix).not.toContain('>')
|
||||
expect(typeof prefix).toBe('string')
|
||||
|
||||
console.log('\n✅ Translation structure is secure - no HTML content')
|
||||
})
|
||||
|
||||
it('should confirm React automatic escaping works correctly', () => {
|
||||
console.log('\n⚡ React Security Mechanism Test')
|
||||
console.log('=================================')
|
||||
|
||||
// Test React's automatic escaping with various inputs
|
||||
const dangerousInputs = [
|
||||
'<script>alert("xss")</script>',
|
||||
'<img src="x" onerror="alert(1)">',
|
||||
'"><script>alert(2)</script>',
|
||||
'\'>alert(3)</script>',
|
||||
'<div onclick="alert(4)">click</div>',
|
||||
]
|
||||
|
||||
dangerousInputs.forEach((input, index) => {
|
||||
const TestComponent = () => <strong>{input}</strong>
|
||||
const { container } = render(<TestComponent />)
|
||||
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`)
|
||||
console.log(` - Rendered as text: ${strongElement?.textContent === input}`)
|
||||
console.log(` - No script execution: ${scriptElements.length === 0}`)
|
||||
|
||||
expect(strongElement?.textContent).toBe(input)
|
||||
expect(scriptElements).toHaveLength(0)
|
||||
})
|
||||
|
||||
console.log('\n🛡️ React automatic escaping is working perfectly')
|
||||
})
|
||||
})
|
||||
|
||||
export {}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue