Merge branch 'main' into feat/model-auth

zxhlyh 2025-08-06 10:33:05 +08:00
commit 4e6cb26778
268 changed files with 12829 additions and 1436 deletions

.github/ISSUE_TEMPLATE/chore.yaml

@ -0,0 +1,27 @@
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
body:
- type: textarea
id: description
attributes:
label: Description
placeholder: "Describe the refactor you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor is necessary."
validations:
required: false
- type: textarea
id: additional-context
attributes:
label: Additional Context
placeholder: "Add any other context or screenshots about the request here."
validations:
required: false


@ -99,3 +99,6 @@ jobs:
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh


@ -7,6 +7,7 @@ on:
- "deploy/dev"
- "deploy/enterprise"
- "build/**"
- "release/e-*"
tags:
- "*"


@ -5,6 +5,7 @@ import secrets
from typing import Any, Optional
import click
import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
@ -457,7 +458,7 @@ def convert_to_agent_apps():
"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query))
rs = conn.execute(sa.text(sql_query))
apps = []
for i in rs:
@ -702,7 +703,7 @@ def fix_app_site_missing():
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
where sites.id is null limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
rs = conn.execute(sa.text(sql))
processed_count = 0
for i in rs:
@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool):
)
orphaned_message_files = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool):
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool):
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
except Exception as e:
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
return
@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
rs = conn.execute(sa.text(query))
for i in rs:
all_files_in_tables.append(str(i[0]))
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))


@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])


@ -84,6 +84,7 @@ from .datasets import (
external,
hit_testing,
metadata,
upload_file,
website,
)


@ -67,7 +67,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "message_count": i.message_count})
@ -176,7 +176,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -234,7 +234,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
@ -310,7 +310,7 @@ ORDER BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
@ -373,7 +373,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -435,7 +435,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
@ -495,7 +495,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})


@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -71,7 +72,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs})
@ -133,7 +134,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
@ -195,7 +196,7 @@ WHERE
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{
@ -277,7 +278,7 @@ GROUP BY
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
rs = conn.execute(sa.text(sql_query), arg_dict)
for i in rs:
response_data.append(
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}


@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource):
return marshal(document_dict, document_status_fields)
class DocumentDetailApi(DocumentResource):
class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"}
@setup_required
@ -730,6 +730,28 @@ class DocumentDetailApi(DocumentResource):
return response, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentProcessingApi(DocumentResource):
@setup_required
@ -768,30 +790,6 @@ class DocumentProcessingApi(DocumentResource):
return {"result": "success"}, 200
class DocumentDeleteApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@ -1037,11 +1035,10 @@ api.add_resource(
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")


@ -0,0 +1,62 @@
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(Resource):
@setup_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
.first()
)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
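A sketch of calling the new upload-file route; the host and IDs below are placeholders, and the request must carry a logged-in console session (the handler is guarded by setup_required and account_initialization_required):

import requests

BASE = "https://dify.example.com/console/api"  # placeholder host
DATASET_ID = "00000000-0000-0000-0000-000000000000"   # placeholder UUIDs
DOCUMENT_ID = "00000000-0000-0000-0000-000000000001"

# Assumes an authenticated console session (cookies obtained elsewhere).
session = requests.Session()
resp = session.get(
    f"{BASE}/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/upload-file",
    timeout=10,
)
resp.raise_for_status()
info = resp.json()
print(info["name"], info["size"], info["download_url"])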


@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
code = 429
class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit"
class ComplianceRateLimitError(BaseHTTPException):
error_code = "compliance_rate_limit"
description = "Rate limit exceeded for downloading compliance report."
code = 429


@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource):
# filter out apps that user doesn't have access to
if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id
res = []
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
# Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list:
webapp_setting = webapp_settings.get(installed_app["app"].id)
if not webapp_setting:
app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue
if webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_code=app_code,
):
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id,
app_codes=app_codes,
)
# Keep only allowed apps
res = []
for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id]
if permissions.get(app_code):
res.append(installed_app)
installed_app_list = res
logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)


@ -2,7 +2,7 @@ import logging
from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.service_api import api
@ -30,6 +30,7 @@ from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -113,7 +114,7 @@ class ChatApi(Resource):
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
parser.add_argument("workflow_id", type=str, required=False, location="json")
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
@ -128,6 +129,12 @@ class ChatApi(Resource):
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:


@ -5,7 +5,7 @@ from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api import api
from controllers.service_api.app.error import (
@ -34,6 +34,7 @@ from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
raise InternalServerError()
class WorkflowRunByIdApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
"""
Run specific workflow by ID
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
)
return helper.compact_generate_response(response)
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except IsDraftWorkflowError as ex:
raise BadRequest(str(ex))
except WorkflowIdFormatError as ex:
raise BadRequest(str(ex))
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str):
@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
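A sketch of invoking the new run-by-id route over the service API, mirroring how the existing /workflows/run endpoint is called; host, token, and workflow ID are placeholders:

import requests

BASE = "https://dify.example.com/v1"  # placeholder host; service API prefix
WORKFLOW_ID = "00000000-0000-0000-0000-0000000000ab"  # placeholder

resp = requests.post(
    f"{BASE}/workflows/{WORKFLOW_ID}/run",
    headers={"Authorization": "Bearer app-xxxxxxxx"},  # service API app token
    json={"inputs": {}, "response_mode": "blocking", "user": "end-user-1"},
    timeout=60,
)
print(resp.status_code, resp.json())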


@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
class DocumentDetailApi(DatasetApiResource):
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
def get(self, tenant_id, dataset_id, document_id):
@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):
return response
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
try:
# delete document
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
api.add_resource(
DocumentAddByTextApi,
@ -588,7 +586,6 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")


@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
default: Optional[Union[int, str, float, bool]] = None
options: Optional[list[Option]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
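Widening the union matters because the field (named "default") is populated from provider manifests and pydantic validates the incoming value against the annotation: under Optional[Union[int, str]], a bool or float value may be coerced (for example True to 1 in lax mode) or rejected, while the widened union keeps it as given. A minimal standalone model to illustrate, not the real ProviderConfig:

from typing import Optional, Union
from pydantic import BaseModel

class Cfg(BaseModel):
    default: Optional[Union[int, str, float, bool]] = None

print(Cfg(default=True).default)  # True, kept as a bool
print(Cfg(default=0.5).default)   # 0.5, kept as a float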


@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
case FileAttribute.TRANSFER_METHOD:
return file.transfer_method.value
case FileAttribute.URL:
return file.remote_url
return _to_url(file)
case FileAttribute.EXTENSION:
return file.extension
case FileAttribute.RELATED_ID:


@ -208,6 +208,7 @@ class BasePluginClient:
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
self._handle_plugin_daemon_error(error.error_type, error.message)
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
if rep.data is None:


@ -2,6 +2,8 @@ from collections.abc import Mapping
from pydantic import TypeAdapter
from extensions.ext_logging import get_request_id
class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors."""
@ -11,7 +13,7 @@ class PluginDaemonError(Exception):
def __str__(self) -> str:
# returns the class name and description
return f"{self.__class__.__name__}: {self.description}"
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
class PluginDaemonInternalError(PluginDaemonError):


@ -7,6 +7,7 @@ from urllib.parse import urlparse
import requests
from elasticsearch import Elasticsearch
from flask import current_app
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator
from core.rag.datasource.vdb.field import Field
@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
def get_type(self) -> str:
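The fix above avoids comparing version strings lexicographically, where "10.1.0" sorts before "8.0.0" because the character "1" sorts before "8". A quick illustration:

from packaging.version import parse as parse_version

assert "10.1.0" < "8.0.0"                                # string compare: wrong
assert parse_version("10.1.0") > parse_version("8.0.0")  # version compare: right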


@ -20,9 +20,6 @@ class Tool(ABC):
The base class of a tool
"""
entity: ToolEntity
runtime: ToolRuntime
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
self.entity = entity
self.runtime = runtime


@ -20,8 +20,6 @@ class BuiltinTool(Tool):
:param meta: the meta data of a tool call processing
"""
provider: str
def __init__(self, provider: str, **kwargs):
super().__init__(**kwargs)
self.provider = provider


@ -21,9 +21,6 @@ API_TOOL_DEFAULT_TIMEOUT = (
class ApiTool(Tool):
api_bundle: ApiToolBundle
provider_id: str
"""
Api tool
"""


@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
class MCPTool(Tool):
tenant_id: str
icon: str
runtime_parameters: Optional[list[ToolParameter]]
server_url: str
provider_id: str
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.runtime_parameters = None
self.server_url = server_url
self.provider_id = provider_id


@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class PluginTool(Tool):
tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[ToolParameter]]
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
@ -21,7 +16,7 @@ class PluginTool(Tool):
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None
self.runtime_parameters: Optional[list[ToolParameter]] = None
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.PLUGIN


@ -7,6 +7,7 @@ from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from yarl import URL
@ -616,7 +617,7 @@ class ToolManager:
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod


@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
class DatasetRetrieverTool(Tool):
retrieval_tool: DatasetRetrieverBaseTool
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool


@ -25,15 +25,6 @@ logger = logging.getLogger(__name__)
class WorkflowTool(Tool):
workflow_app_id: str
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
workflow_as_tool_id: str
label: str
"""
Workflow tool.
"""


@ -109,7 +109,7 @@ class SegmentType(StrEnum):
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
# ARRAY_ANY does not have corresponding element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
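The first hunk above fixes a subtle truthiness bug: the old code wrapped the per-element check in a one-element list, so all() iterated over non-empty lists, which are always truthy, and the validation could never fail. A quick illustration:

checks = [False, False]
print(all([c] for c in checks))  # True  - buggy form: always passes
print(all(c for c in checks))    # False - fixed form: actually validates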


@ -318,6 +318,33 @@ class ToolNode(BaseNode):
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": message.message.text,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])


@ -136,6 +136,8 @@ def init_app(app: DifyApp):
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
from opentelemetry.propagate import set_global_textmap
@ -234,6 +236,8 @@ def init_app(app: DifyApp):
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
instrument_exception_logging()
init_sqlalchemy_instrumentor(app)
RedisInstrumentor().instrument()
RequestsInstrumentor().instrument()
atexit.register(shutdown_tracer)


@ -59,6 +59,8 @@ model_config_fields = {
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_detail_fields = {
"id": fields.String,
"name": fields.String,
@ -77,6 +79,7 @@ app_detail_fields = {
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
}
prompt_config_fields = {
@ -92,8 +95,6 @@ model_config_partial_fields = {
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_partial_fields = {
"id": fields.String,
"name": fields.String,
@ -185,7 +186,6 @@ app_detail_fields_with_site = {
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
@ -195,6 +195,8 @@ app_detail_fields_with_site = {
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"site": fields.Nested(site_fields),
}


@ -3,6 +3,7 @@ import json
from datetime import datetime
from typing import Optional, cast
import sqlalchemy as sa
from flask_login import UserMixin # type: ignore
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
@ -83,9 +84,9 @@ class AccountStatus(enum.StrEnum):
class Account(UserMixin, Base):
__tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
email: Mapped[str] = mapped_column(String(255))
password: Mapped[Optional[str]] = mapped_column(String(255))
@ -97,7 +98,7 @@ class Account(UserMixin, Base):
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying"))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
@ -195,14 +196,14 @@ class TenantStatus(enum.StrEnum):
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255))
encrypt_public_key = db.Column(db.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying"))
custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
encrypt_public_key = db.Column(sa.Text)
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
@ -225,16 +226,16 @@ class Tenant(Base):
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
db.Index("tenant_account_join_account_id_idx", "account_id"),
db.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
sa.Index("tenant_account_join_account_id_idx", "account_id"),
sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
account_id: Mapped[str] = mapped_column(StringUUID)
current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
role: Mapped[str] = mapped_column(String(16), server_default="normal")
invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
@ -244,12 +245,12 @@ class TenantAccountJoin(Base):
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
db.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id: Mapped[str] = mapped_column(StringUUID)
provider: Mapped[str] = mapped_column(String(16))
open_id: Mapped[str] = mapped_column(String(255))
@ -261,20 +262,20 @@ class AccountIntegrate(Base):
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
db.Index("invitation_codes_batch_idx", "batch"),
db.Index("invitation_codes_code_idx", "code", "status"),
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
sa.Index("invitation_codes_batch_idx", "batch"),
sa.Index("invitation_codes_code_idx", "code", "status"),
)
id: Mapped[int] = mapped_column(db.Integer)
id: Mapped[int] = mapped_column(sa.Integer)
batch: Mapped[str] = mapped_column(String(255))
code: Mapped[str] = mapped_column(String(32))
status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying"))
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
class TenantPluginPermission(Base):
@ -290,11 +291,11 @@ class TenantPluginPermission(Base):
__tablename__ = "account_plugin_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
@ -313,16 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base):
__tablename__ = "tenant_plugin_auto_upgrade_strategies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())


@ -1,11 +1,11 @@
import enum
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
@ -19,11 +19,11 @@ class APIBasedExtensionPoint(enum.Enum):
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
db.Index("api_based_extension_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)


@ -12,6 +12,7 @@ from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
@ -38,23 +39,23 @@ class DatasetPermissionEnum(enum.StrEnum):
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
db.Index("dataset_tenant_idx", "tenant_id"),
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
sa.Index("dataset_tenant_idx", "tenant_id"),
sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
)
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description = mapped_column(db.Text, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying"))
permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying"))
description = mapped_column(sa.Text, nullable=True)
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
data_source_type = mapped_column(String(255))
indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
index_struct = mapped_column(db.Text, nullable=True)
index_struct = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@ -63,7 +64,7 @@ class Dataset(Base):
embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(JSONB, nullable=True)
built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def dataset_keyword_table(self):
@ -262,14 +263,14 @@ class Dataset(Base):
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
rules = mapped_column(db.Text, nullable=True)
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
rules = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -302,20 +303,20 @@ class DatasetProcessRule(Base):
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
db.Index("document_dataset_id_idx", "dataset_id"),
db.Index("document_is_paused_idx", "is_paused"),
db.Index("document_tenant_idx", "tenant_id"),
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="document_pkey"),
sa.Index("document_dataset_id_idx", "dataset_id"),
sa.Index("document_is_paused_idx", "is_paused"),
sa.Index("document_tenant_idx", "tenant_id"),
sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
data_source_info = mapped_column(db.Text, nullable=True)
data_source_info = mapped_column(sa.Text, nullable=True)
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
batch: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -328,8 +329,8 @@ class Document(Base):
processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# parsing
file_id = mapped_column(db.Text, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) # TODO: make this not nullable
file_id = mapped_column(sa.Text, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# cleaning
@ -339,32 +340,32 @@ class Document(Base):
splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# indexing
tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# pause
is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
paused_by = mapped_column(StringUUID, nullable=True)
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# error
error = mapped_column(db.Text, nullable=True)
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
# basic fields
indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
archived_reason = mapped_column(String(255), nullable=True)
archived_by = mapped_column(StringUUID, nullable=True)
archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = mapped_column(String(40), nullable=True)
doc_metadata = mapped_column(JSONB, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
doc_language = mapped_column(String(255), nullable=True)
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@ -643,44 +644,44 @@ class Document(Base):
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
db.Index("document_segment_dataset_id_idx", "dataset_id"),
db.Index("document_segment_document_id_idx", "document_id"),
db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
db.Index("document_segment_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
sa.Index("document_segment_dataset_id_idx", "dataset_id"),
sa.Index("document_segment_document_id_idx", "document_id"),
sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
sa.Index("document_segment_tenant_idx", "tenant_id"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int]
content = mapped_column(db.Text, nullable=False)
answer = mapped_column(db.Text, nullable=True)
content = mapped_column(sa.Text, nullable=False)
answer = mapped_column(sa.Text, nullable=True)
word_count: Mapped[int]
tokens: Mapped[int]
# indexing fields
keywords = mapped_column(db.JSON, nullable=True)
keywords = mapped_column(sa.JSON, nullable=True)
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
# basic fields
hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying"))
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
error = mapped_column(db.Text, nullable=True)
error = mapped_column(sa.Text, nullable=True)
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
@property
@ -794,36 +795,36 @@ class DocumentSegment(Base):
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
db.Index("child_chunks_segment_idx", "segment_id"),
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
sa.Index("child_chunks_segment_idx", "segment_id"),
)
# initial fields
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
segment_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
content = mapped_column(db.Text, nullable=False)
word_count: Mapped[int] = mapped_column(db.Integer, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
content = mapped_column(sa.Text, nullable=False)
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
# indexing fields
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
error = mapped_column(db.Text, nullable=True)
error = mapped_column(sa.Text, nullable=True)
@property
def dataset(self):
@@ -841,11 +842,11 @@ class ChildChunk(Base):
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
)
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -858,13 +859,13 @@ class AppDatasetJoin(Base):
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
db.Index("dataset_query_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
sa.Index("dataset_query_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False)
content = mapped_column(db.Text, nullable=False)
content = mapped_column(sa.Text, nullable=False)
source: Mapped[str] = mapped_column(String(255), nullable=False)
source_app_id = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String, nullable=False)
@@ -875,15 +876,15 @@ class DatasetQuery(Base):
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
keyword_table = mapped_column(db.Text, nullable=False)
keyword_table = mapped_column(sa.Text, nullable=False)
data_source_type = mapped_column(
String(255), nullable=False, server_default=db.text("'database'::character varying")
String(255), nullable=False, server_default=sa.text("'database'::character varying")
)
@property
@@ -920,19 +921,19 @@ class DatasetKeywordTable(Base):
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
db.Index("created_at_idx", "created_at"),
sa.PrimaryKeyConstraint("id", name="embedding_pkey"),
sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
sa.Index("created_at_idx", "created_at"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
model_name = mapped_column(
String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying")
)
hash = mapped_column(String(64), nullable=False)
embedding = mapped_column(db.LargeBinary, nullable=False)
embedding = mapped_column(sa.LargeBinary, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying"))
provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying"))
def set_embedding(self, embedding_data: list[float]):
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
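`set_embedding` pickles the vector into the `LargeBinary` column; the matching read path (elided by this hunk) is just the inverse. A sketch, assuming the bytes were written by `set_embedding`:

import pickle

def get_embedding_sketch(raw: bytes) -> list[float]:
    # pickle.loads inverts pickle.dumps; the protocol version is self-described
    # in the payload, so HIGHEST_PROTOCOL needs no mention on the read side.
    return pickle.loads(raw)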
@@ -944,14 +945,14 @@ class Embedding(Base):
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
db.Index("provider_model_name_idx", "provider_name", "model_name"),
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False)
collection_name = mapped_column(String(64), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -959,17 +960,17 @@ class DatasetCollectionBinding(Base):
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
db.Index("tidb_auth_bindings_active_idx", "active"),
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
db.Index("tidb_auth_bindings_status_idx", "status"),
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
sa.Index("tidb_auth_bindings_active_idx", "active"),
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -979,10 +980,10 @@ class TidbAuthBinding(Base):
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
db.Index("whitelists_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
sa.Index("whitelists_tenant_idx", "tenant_id"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
category: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@@ -991,33 +992,33 @@ class Whitelist(Base):
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
db.Index("idx_dataset_permissions_account_id", "account_id"),
db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
sa.Index("idx_dataset_permissions_account_id", "account_id"),
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True)
dataset_id = mapped_column(StringUUID, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
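These models mix two default mechanisms: `default=` is applied by the ORM in Python at flush time, while `server_default=sa.text(...)` is compiled into the `CREATE TABLE` DDL so the database fills the value even for inserts that bypass the ORM. A toy sketch of the contrast (hypothetical table):

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Flag(Base):
    __tablename__ = "flags_sketch"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Set by SQLAlchemy in Python when the object is flushed:
    hit_count: Mapped[int] = mapped_column(sa.Integer, default=0)
    # Emitted as "DEFAULT true" in the DDL; applied even to raw SQL inserts:
    has_permission: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("true"))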
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
db.Index("external_knowledge_apis_name_idx", "name"),
sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
sa.Index("external_knowledge_apis_name_idx", "name"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
settings = mapped_column(db.Text, nullable=True)
settings = mapped_column(sa.Text, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@@ -1061,18 +1062,18 @@ class ExternalKnowledgeApis(Base):
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
external_knowledge_api_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
external_knowledge_id = mapped_column(db.Text, nullable=False)
external_knowledge_id = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
@@ -1082,57 +1083,57 @@ class ExternalKnowledgeBindings(Base):
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
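Two timestamp defaults are in play here: `func.current_timestamp()` compiles to `CURRENT_TIMESTAMP` (microsecond precision on PostgreSQL), while `sa.text("CURRENT_TIMESTAMP(0)")` rounds the stored value to whole seconds. Both are server-side defaults; only the precision differs:

import sqlalchemy as sa
from sqlalchemy import func

fine = func.current_timestamp()            # DEFAULT CURRENT_TIMESTAMP
coarse = sa.text("CURRENT_TIMESTAMP(0)")   # DEFAULT CURRENT_TIMESTAMP(0), whole seconds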
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
db.Index("rate_limit_log_tenant_idx", "tenant_id"),
db.Index("rate_limit_log_operation_idx", "operation"),
sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
sa.Index("rate_limit_log_tenant_idx", "tenant_id"),
sa.Index("rate_limit_log_operation_idx", "operation"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
operation: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
sa.Index("dataset_metadata_tenant_idx", "tenant_id"),
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
created_by = mapped_column(StringUUID, nullable=False)
updated_by = mapped_column(StringUUID, nullable=True)
@@ -1141,14 +1142,14 @@ class DatasetMetadata(Base):
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
db.Index("dataset_metadata_binding_document_idx", "document_id"),
sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
metadata_id = mapped_column(StringUUID, nullable=False)

View File

@@ -35,10 +35,10 @@ from .types import StringUUID
class DifySetup(Base):
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
version: Mapped[str] = mapped_column(String(255), nullable=False)
setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMode(StrEnum):
@@ -69,33 +69,33 @@ class IconType(Enum):
class App(Base):
__tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
name: Mapped[str] = mapped_column(String(255))
description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying"))
description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
mode: Mapped[str] = mapped_column(String(255))
icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji
icon = db.Column(String(255))
icon_background: Mapped[Optional[str]] = mapped_column(String(255))
app_model_config_id = mapped_column(StringUUID, nullable=True)
workflow_id = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying"))
enable_site: Mapped[bool] = mapped_column(db.Boolean)
enable_api: Mapped[bool] = mapped_column(db.Boolean)
api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
tracing = mapped_column(db.Text, nullable=True)
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
enable_site: Mapped[bool] = mapped_column(sa.Boolean)
enable_api: Mapped[bool] = mapped_column(sa.Boolean)
api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
tracing = mapped_column(sa.Text, nullable=True)
max_active_requests: Mapped[Optional[int]]
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def desc_or_prompt(self):
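One detail worth noting in `App`: `max_active_requests: Mapped[Optional[int]]` carries no `mapped_column()` call at all. Under SQLAlchemy 2.0 declarative mapping the annotation alone is enough; column type and nullability are inferred from it. A sketch:

from typing import Optional
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ToyApp(Base):
    __tablename__ = "toy_apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    # No mapped_column(): becomes a nullable INTEGER column named max_active_requests.
    max_active_requests: Mapped[Optional[int]]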
@@ -302,36 +302,36 @@ class App(Base):
class AppModelConfig(Base):
__tablename__ = "app_model_configs"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
provider = mapped_column(String(255), nullable=True)
model_id = mapped_column(String(255), nullable=True)
configs = mapped_column(db.JSON, nullable=True)
configs = mapped_column(sa.JSON, nullable=True)
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
opening_statement = mapped_column(db.Text)
suggested_questions = mapped_column(db.Text)
suggested_questions_after_answer = mapped_column(db.Text)
speech_to_text = mapped_column(db.Text)
text_to_speech = mapped_column(db.Text)
more_like_this = mapped_column(db.Text)
model = mapped_column(db.Text)
user_input_form = mapped_column(db.Text)
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
opening_statement = mapped_column(sa.Text)
suggested_questions = mapped_column(sa.Text)
suggested_questions_after_answer = mapped_column(sa.Text)
speech_to_text = mapped_column(sa.Text)
text_to_speech = mapped_column(sa.Text)
more_like_this = mapped_column(sa.Text)
model = mapped_column(sa.Text)
user_input_form = mapped_column(sa.Text)
dataset_query_variable = mapped_column(String(255))
pre_prompt = mapped_column(db.Text)
agent_mode = mapped_column(db.Text)
sensitive_word_avoidance = mapped_column(db.Text)
retriever_resource = mapped_column(db.Text)
prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying"))
chat_prompt_config = mapped_column(db.Text)
completion_prompt_config = mapped_column(db.Text)
dataset_configs = mapped_column(db.Text)
external_data_tools = mapped_column(db.Text)
file_upload = mapped_column(db.Text)
pre_prompt = mapped_column(sa.Text)
agent_mode = mapped_column(sa.Text)
sensitive_word_avoidance = mapped_column(sa.Text)
retriever_resource = mapped_column(sa.Text)
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying"))
chat_prompt_config = mapped_column(sa.Text)
completion_prompt_config = mapped_column(sa.Text)
dataset_configs = mapped_column(sa.Text)
external_data_tools = mapped_column(sa.Text)
file_upload = mapped_column(sa.Text)
@property
def app(self):
@@ -553,24 +553,24 @@ class AppModelConfig(Base):
class RecommendedApp(Base):
__tablename__ = "recommended_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
db.Index("recommended_app_app_id_idx", "app_id"),
db.Index("recommended_app_is_listed_idx", "is_listed", "language"),
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
sa.Index("recommended_app_app_id_idx", "app_id"),
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
description = mapped_column(db.JSON, nullable=False)
description = mapped_column(sa.JSON, nullable=False)
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
category: Mapped[str] = mapped_column(String(255), nullable=False)
position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True)
install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@@ -581,20 +581,20 @@ class RecommendedApp(Base):
class InstalledApp(Base):
__tablename__ = "installed_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
db.Index("installed_app_tenant_id_idx", "tenant_id"),
db.Index("installed_app_app_id_idx", "app_id"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
sa.PrimaryKeyConstraint("id", name="installed_app_pkey"),
sa.Index("installed_app_tenant_id_idx", "tenant_id"),
sa.Index("installed_app_app_id_idx", "app_id"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
app_owner_tenant_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
last_used_at = mapped_column(db.DateTime, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@@ -610,23 +610,23 @@ class InstalledApp(Base):
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
app_model_config_id = mapped_column(StringUUID, nullable=True)
model_provider = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(db.Text)
override_model_configs = mapped_column(sa.Text)
model_id = mapped_column(String(255), nullable=True)
mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(db.Text)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
introduction = mapped_column(db.Text)
system_instruction = mapped_column(db.Text)
system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
status: Mapped[str] = mapped_column(String(255), nullable=False)
# The `invoke_from` records how the conversation is created.
@@ -639,18 +639,18 @@ class Conversation(Base):
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
read_at = mapped_column(db.DateTime)
read_at = mapped_column(sa.DateTime)
read_account_id = mapped_column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
)
is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def inputs(self):
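`_inputs` maps the physical `inputs` column onto a private attribute so the `inputs` property can normalize what callers see. A sketch of the idiom (toy class; the real accessor body is outside this hunk):

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ToyConversation(Base):
    __tablename__ = "toy_conversations"
    id: Mapped[int] = mapped_column(primary_key=True)
    # The first positional argument names the table column; the attribute stays private.
    _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)

    @property
    def inputs(self) -> dict:
        return self._inputs or {}

    @inputs.setter
    def inputs(self, value: dict) -> None:
        self._inputs = value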
@@ -892,36 +892,36 @@ class Message(Base):
Index("message_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
model_provider = mapped_column(String(255), nullable=True)
model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(db.Text)
conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = mapped_column(db.Text, nullable=False)
message = mapped_column(db.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
parent_message_id = mapped_column(StringUUID, nullable=True)
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
total_price = mapped_column(db.Numeric(10, 7))
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
error = mapped_column(db.Text)
message_metadata = mapped_column(db.Text)
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
error = mapped_column(sa.Text)
message_metadata = mapped_column(sa.Text)
invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
@property
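`conversation_id` is one of the few columns in this file given a real `sa.ForeignKey`, and it pairs with the `db.relationship(..., passive_deletes="all")` declared on `Conversation` above. The same pairing in plain SQLAlchemy, sketched with toy names:

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class ToyConversation2(Base):
    __tablename__ = "toy_conversations2"
    id: Mapped[int] = mapped_column(primary_key=True)
    # passive_deletes="all" keeps SQLAlchemy from nulling child FKs on delete.
    messages = relationship("ToyMessage", backref="conversation", lazy="select", passive_deletes="all")

class ToyMessage(Base):
    __tablename__ = "toy_messages"
    id: Mapped[int] = mapped_column(primary_key=True)
    conversation_id: Mapped[int] = mapped_column(sa.ForeignKey("toy_conversations2.id"))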
@@ -1228,23 +1228,23 @@ class Message(Base):
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
db.Index("message_feedback_app_idx", "app_id"),
db.Index("message_feedback_message_idx", "message_id", "from_source"),
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
sa.Index("message_feedback_app_idx", "app_id"),
sa.Index("message_feedback_message_idx", "message_id", "from_source"),
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
conversation_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
content = mapped_column(db.Text)
content = mapped_column(sa.Text)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
from_end_user_id = mapped_column(StringUUID)
from_account_id = mapped_column(StringUUID)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def from_account(self):
@@ -1270,9 +1270,9 @@ class MessageFeedback(Base):
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
db.Index("message_file_message_idx", "message_id"),
db.Index("message_file_created_by_idx", "created_by"),
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
sa.Index("message_file_message_idx", "message_id"),
sa.Index("message_file_created_by_idx", "created_by"),
)
def __init__(
@@ -1296,37 +1296,37 @@ class MessageFile(Base):
self.created_by_role = created_by_role.value
self.created_by = created_by
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
db.Index("message_annotation_app_idx", "app_id"),
db.Index("message_annotation_conversation_idx", "conversation_id"),
db.Index("message_annotation_message_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
sa.Index("message_annotation_app_idx", "app_id"),
sa.Index("message_annotation_conversation_idx", "conversation_id"),
sa.Index("message_annotation_message_idx", "message_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id"))
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
question = db.Column(db.Text, nullable=True)
content = mapped_column(db.Text, nullable=False)
hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
question = db.Column(sa.Text, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def account(self):
@@ -1342,24 +1342,24 @@ class MessageAnnotation(Base):
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
db.Index("app_annotation_hit_histories_app_idx", "app_id"),
db.Index("app_annotation_hit_histories_account_idx", "account_id"),
db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
db.Index("app_annotation_hit_histories_message_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
sa.Index("app_annotation_hit_histories_app_idx", "app_id"),
sa.Index("app_annotation_hit_histories_account_idx", "account_id"),
sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source = mapped_column(db.Text, nullable=False)
question = mapped_column(db.Text, nullable=False)
source = mapped_column(sa.Text, nullable=False)
question = mapped_column(sa.Text, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=db.text("0"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id = mapped_column(StringUUID, nullable=False)
annotation_question = mapped_column(db.Text, nullable=False)
annotation_content = mapped_column(db.Text, nullable=False)
annotation_question = mapped_column(sa.Text, nullable=False)
annotation_content = mapped_column(sa.Text, nullable=False)
@property
def account(self):
@@ -1380,18 +1380,18 @@ class AppAnnotationHitHistory(Base):
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
db.Index("app_annotation_settings_app_idx", "app_id"),
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
sa.Index("app_annotation_settings_app_idx", "app_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0"))
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
collection_binding_id = mapped_column(StringUUID, nullable=False)
created_user_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = mapped_column(StringUUID, nullable=False)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def collection_binding_detail(self):
@@ -1408,58 +1408,58 @@ class AppAnnotationSetting(Base):
class OperationLog(Base):
__tablename__ = "operation_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
sa.PrimaryKeyConstraint("id", name="operation_log_pkey"),
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
action: Mapped[str] = mapped_column(String(255), nullable=False)
content = mapped_column(db.JSON)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
content = mapped_column(sa.JSON)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class EndUser(Base, UserMixin):
__tablename__ = "end_users"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
db.Index("end_user_session_id_idx", "session_id", "type"),
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
sa.PrimaryKeyConstraint("id", name="end_user_pkey"),
sa.Index("end_user_session_id_idx", "session_id", "type"),
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
session_id: Mapped[str] = mapped_column()
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMCPServer(Base):
__tablename__ = "app_mcp_servers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=False)
server_code: Mapped[str] = mapped_column(String(255), nullable=False)
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
parameters = mapped_column(db.Text, nullable=False)
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
parameters = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n):
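The body of `generate_server_code` falls outside this hunk. A purely hypothetical sketch of such a generator, not the method's actual body (the real one may also, e.g., retry against the `server_code` unique constraint on collision):

import secrets
import string

def generate_server_code_sketch(n: int) -> str:
    # Hypothetical: n characters drawn from a cryptographically secure RNG.
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(n))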
@@ -1478,34 +1478,34 @@ class AppMCPServer(Base):
class Site(Base):
__tablename__ = "sites"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="site_pkey"),
db.Index("site_app_id_idx", "app_id"),
db.Index("site_code_idx", "code", "status"),
sa.PrimaryKeyConstraint("id", name="site_pkey"),
sa.Index("site_app_id_idx", "app_id"),
sa.Index("site_code_idx", "code", "status"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
icon_type = mapped_column(String(255), nullable=True)
icon = mapped_column(String(255))
icon_background = mapped_column(String(255))
description = mapped_column(db.Text)
description = mapped_column(sa.Text)
default_language: Mapped[str] = mapped_column(String(255), nullable=False)
chat_color_theme = mapped_column(String(255))
chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
copyright = mapped_column(String(255))
privacy_policy = mapped_column(String(255))
show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
customize_domain = mapped_column(String(255))
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
created_by = mapped_column(StringUUID, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
code = mapped_column(String(255))
@property
@@ -1535,19 +1535,19 @@ class Site(Base):
class ApiToken(Base):
__tablename__ = "api_tokens"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
db.Index("api_token_app_id_type_idx", "app_id", "type"),
db.Index("api_token_token_idx", "token", "type"),
db.Index("api_token_tenant_idx", "tenant_id", "type"),
sa.PrimaryKeyConstraint("id", name="api_token_pkey"),
sa.Index("api_token_app_id_type_idx", "app_id", "type"),
sa.Index("api_token_token_idx", "token", "type"),
sa.Index("api_token_tenant_idx", "tenant_id", "type"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=True)
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
token: Mapped[str] = mapped_column(String(255), nullable=False)
last_used_at = mapped_column(db.DateTime, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
last_used_at = mapped_column(sa.DateTime, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
@@ -1561,26 +1561,26 @@ class ApiToken(Base):
class UploadFile(Base):
__tablename__ = "upload_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
db.Index("upload_file_tenant_idx", "tenant_id"),
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
sa.Index("upload_file_tenant_idx", "tenant_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
key: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
size: Mapped[int] = mapped_column(db.Integer, nullable=False)
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
extension: Mapped[str] = mapped_column(String(255), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
created_by_role: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=db.text("'account'::character varying")
String(255), nullable=False, server_default=sa.text("'account'::character varying")
)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True)
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
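`UploadFile` mixes `Optional[datetime]` with the PEP 604 spelling `str | None`; SQLAlchemy 2.0 treats the two identically, inferring a nullable column from either. A sketch:

from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class ToyFile(Base):
    __tablename__ = "toy_files"
    id: Mapped[int] = mapped_column(primary_key=True)
    hash: Mapped[str | None] = mapped_column(sa.String(255))          # PEP 604 union
    used_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime)  # typing.Optional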
@@ -1623,71 +1623,71 @@ class UploadFile(Base):
class ApiRequest(Base):
__tablename__ = "api_requests"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
db.Index("api_request_token_idx", "tenant_id", "api_token_id"),
sa.PrimaryKeyConstraint("id", name="api_request_pkey"),
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
api_token_id = mapped_column(StringUUID, nullable=False)
path: Mapped[str] = mapped_column(String(255), nullable=False)
request = mapped_column(db.Text, nullable=True)
response = mapped_column(db.Text, nullable=True)
request = mapped_column(sa.Text, nullable=True)
response = mapped_column(sa.Text, nullable=True)
ip: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageChain(Base):
__tablename__ = "message_chains"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
db.Index("message_chain_message_id_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="message_chain_pkey"),
sa.Index("message_chain_message_id_idx", "message_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(db.Text, nullable=True)
output = mapped_column(db.Text, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
input = mapped_column(sa.Text, nullable=True)
output = mapped_column(sa.Text, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(Base):
__tablename__ = "message_agent_thoughts"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
db.Index("message_agent_thought_message_id_idx", "message_id"),
db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
sa.Index("message_agent_thought_message_id_idx", "message_id"),
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
message_chain_id = mapped_column(StringUUID, nullable=True)
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
thought = mapped_column(db.Text, nullable=True)
tool = mapped_column(db.Text, nullable=True)
tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = mapped_column(db.Text, nullable=True)
observation = mapped_column(db.Text, nullable=True)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
thought = mapped_column(sa.Text, nullable=True)
tool = mapped_column(sa.Text, nullable=True)
tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
tool_input = mapped_column(sa.Text, nullable=True)
observation = mapped_column(sa.Text, nullable=True)
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
tool_process_data = mapped_column(db.Text, nullable=True)
message = mapped_column(db.Text, nullable=True)
message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
message_unit_price = mapped_column(db.Numeric, nullable=True)
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
message_files = mapped_column(db.Text, nullable=True)
answer = db.Column(db.Text, nullable=True)
answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
answer_unit_price = mapped_column(db.Numeric, nullable=True)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
total_price = mapped_column(db.Numeric, nullable=True)
tool_process_data = mapped_column(sa.Text, nullable=True)
message = mapped_column(sa.Text, nullable=True)
message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
message_unit_price = mapped_column(sa.Numeric, nullable=True)
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
message_files = mapped_column(sa.Text, nullable=True)
answer = db.Column(sa.Text, nullable=True)
answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
total_price = mapped_column(sa.Numeric, nullable=True)
currency = mapped_column(String, nullable=True)
latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def files(self) -> list:
@ -1769,80 +1769,80 @@ class MessageAgentThought(Base):
class DatasetRetrieverResource(Base):
__tablename__ = "dataset_retriever_resources"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
db.Index("dataset_retriever_resource_message_id_idx", "message_id"),
sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
message_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
dataset_name = mapped_column(db.Text, nullable=False)
dataset_name = mapped_column(sa.Text, nullable=False)
document_id = mapped_column(StringUUID, nullable=True)
document_name = mapped_column(db.Text, nullable=False)
data_source_type = mapped_column(db.Text, nullable=True)
document_name = mapped_column(sa.Text, nullable=False)
data_source_type = mapped_column(sa.Text, nullable=True)
segment_id = mapped_column(StringUUID, nullable=True)
score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
content = mapped_column(db.Text, nullable=False)
hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
index_node_hash = mapped_column(db.Text, nullable=True)
retriever_from = mapped_column(db.Text, nullable=False)
score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
content = mapped_column(sa.Text, nullable=False)
hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(Base):
__tablename__ = "tags"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_pkey"),
db.Index("tag_type_idx", "type"),
db.Index("tag_name_idx", "name"),
sa.PrimaryKeyConstraint("id", name="tag_pkey"),
sa.Index("tag_type_idx", "type"),
sa.Index("tag_name_idx", "name"),
)
TAG_TYPE_LIST = ["knowledge", "app"]
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
type = mapped_column(String(16), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class TagBinding(Base):
__tablename__ = "tag_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
db.Index("tag_bind_target_id_idx", "target_id"),
db.Index("tag_bind_tag_id_idx", "tag_id"),
sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
sa.Index("tag_bind_target_id_idx", "target_id"),
sa.Index("tag_bind_tag_id_idx", "tag_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=True)
tag_id = mapped_column(StringUUID, nullable=True)
target_id = mapped_column(StringUUID, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class TraceAppConfig(Base):
__tablename__ = "trace_app_config"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
db.Index("trace_app_config_app_id_idx", "app_id"),
sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
sa.Index("trace_app_config_app_id_idx", "app_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tracing_provider = mapped_column(String(255), nullable=True)
tracing_config = mapped_column(db.JSON, nullable=True)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
tracing_config = mapped_column(sa.JSON, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@property
def tracing_config_dict(self):

View File

@ -2,11 +2,11 @@ from datetime import datetime
from enum import Enum
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
from .types import StringUUID
@ -47,9 +47,9 @@ class Provider(Base):
__tablename__ = "providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_pkey"),
db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
db.UniqueConstraint(
sa.PrimaryKeyConstraint("id", name="provider_pkey"),
sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
sa.UniqueConstraint(
"tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota"
),
)
@ -60,15 +60,15 @@ class Provider(Base):
provider_type: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'custom'::character varying")
)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
quota_type: Mapped[Optional[str]] = mapped_column(
String(40), nullable=True, server_default=text("''::character varying")
)
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -104,9 +104,9 @@ class ProviderModel(Base):
__tablename__ = "provider_models"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_model_pkey"),
db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
db.UniqueConstraint(
sa.PrimaryKeyConstraint("id", name="provider_model_pkey"),
sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
sa.UniqueConstraint(
"tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name"
),
)
@ -116,8 +116,8 @@ class ProviderModel(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -125,8 +125,8 @@ class ProviderModel(Base):
class TenantDefaultModel(Base):
__tablename__ = "tenant_default_models"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
@ -141,8 +141,8 @@ class TenantDefaultModel(Base):
class TenantPreferredModelProvider(Base):
__tablename__ = "tenant_preferred_model_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
@ -156,8 +156,8 @@ class TenantPreferredModelProvider(Base):
class ProviderOrder(Base):
__tablename__ = "provider_orders"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
sa.PrimaryKeyConstraint("id", name="provider_order_pkey"),
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
@ -167,9 +167,9 @@ class ProviderOrder(Base):
payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(String(40))
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
payment_status: Mapped[str] = mapped_column(
String(40), nullable=False, server_default=text("'wait_pay'::character varying")
)
@ -187,8 +187,8 @@ class ProviderModelSetting(Base):
__tablename__ = "provider_model_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
@ -196,8 +196,8 @@ class ProviderModelSetting(Base):
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -209,8 +209,8 @@ class LoadBalancingModelConfig(Base):
__tablename__ = "load_balancing_model_configs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
@ -219,7 +219,7 @@ class LoadBalancingModelConfig(Base):
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -2,50 +2,50 @@ import json
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from models.base import Base
from .engine import db
from .types import StringUUID
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
db.Index("source_binding_tenant_id_idx", "tenant_id"),
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info = mapped_column(JSONB, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
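The GIN index on source_info declared above exists to serve JSONB lookups. A hedged query sketch (the helper name and the "workspace_id" key are illustrative assumptions, not part of the diff):

from sqlalchemy import select
from sqlalchemy.orm import Session

def find_bindings_by_source_info(session: Session, tenant_id: str, workspace_id: str):
    # Containment (@>) on the JSONB column is the kind of predicate a
    # postgresql_using="gin" index accelerates.
    stmt = select(DataSourceOauthBinding).where(
        DataSourceOauthBinding.tenant_id == tenant_id,
        DataSourceOauthBinding.source_info.contains({"workspace_id": workspace_id}),
    )
    return session.scalars(stmt).all()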
class DataSourceApiKeyAuthBinding(Base):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
db.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
credentials = mapped_column(db.Text, nullable=True) # JSON
credentials = mapped_column(sa.Text, nullable=True) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
def to_dict(self):
return {

View File

@ -1,6 +1,7 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from celery import states # type: ignore
from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
@ -16,7 +17,7 @@ class CeleryTask(Base):
__tablename__ = "celery_taskmeta"
id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
task_id = mapped_column(String(155), unique=True)
status = mapped_column(String(50), default=states.PENDING)
result = mapped_column(db.PickleType, nullable=True)
@ -26,12 +27,12 @@ class CeleryTask(Base):
onupdate=lambda: naive_utc_now(),
nullable=True,
)
traceback = mapped_column(db.Text, nullable=True)
traceback = mapped_column(sa.Text, nullable=True)
name = mapped_column(String(155), nullable=True)
args = mapped_column(db.LargeBinary, nullable=True)
kwargs = mapped_column(db.LargeBinary, nullable=True)
args = mapped_column(sa.LargeBinary, nullable=True)
kwargs = mapped_column(sa.LargeBinary, nullable=True)
worker = mapped_column(String(155), nullable=True)
retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
queue = mapped_column(String(155), nullable=True)
@ -41,7 +42,7 @@ class CeleryTaskSet(Base):
__tablename__ = "celery_tasksetmeta"
id: Mapped[int] = mapped_column(
db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
)
taskset_id = mapped_column(String(155), unique=True)
result = mapped_column(db.PickleType, nullable=True)

View File

@ -25,33 +25,33 @@ from .types import StringUUID
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
plugin_id = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property
def oauth_params(self) -> dict:
@ -65,14 +65,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
@ -81,19 +81,19 @@ class BuiltinToolProvider(Base):
# name of the tool provider
provider: Mapped[str] = mapped_column(String(256), nullable=False)
# credential of the tool provider
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
String(32), nullable=False, server_default=db.text("'api-key'::character varying")
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
)
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
@property
def credentials(self) -> dict:
@ -107,28 +107,28 @@ class ApiToolProvider(Base):
__tablename__ = "tool_api_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the api provider
name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
# icon
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema = mapped_column(db.Text, nullable=False)
schema = mapped_column(sa.Text, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
# who created this tool
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
# description of the provider
description = mapped_column(db.Text, nullable=False)
description = mapped_column(sa.Text, nullable=False)
# json format tools
tools_str = mapped_column(db.Text, nullable=False)
tools_str = mapped_column(sa.Text, nullable=False)
# json format credentials
credentials_str = mapped_column(db.Text, nullable=False)
credentials_str = mapped_column(sa.Text, nullable=False)
# privacy policy
privacy_policy = mapped_column(String(255), nullable=True)
# custom_disclaimer
@ -167,11 +167,11 @@ class ToolLabelBinding(Base):
__tablename__ = "tool_label_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# tool id
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
# tool type
@ -187,12 +187,12 @@ class WorkflowToolProvider(Base):
__tablename__ = "tool_workflow_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
@ -208,17 +208,17 @@ class WorkflowToolProvider(Base):
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# description of the provider
description: Mapped[str] = mapped_column(db.Text, nullable=False)
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# parameter configuration
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
# privacy policy
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
@property
@ -245,19 +245,19 @@ class MCPToolProvider(Base):
__tablename__ = "tool_mcp_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# name of the mcp provider
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
server_identifier: Mapped[str] = mapped_column(String(64), nullable=False)
# encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
server_url: Mapped[str] = mapped_column(sa.Text, nullable=False)
# hash of server_url for uniqueness check
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
# icon of the mcp provider
@ -267,16 +267,16 @@ class MCPToolProvider(Base):
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# encrypted credentials
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
# authed
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
# tools
tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
)
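server_url is stored encrypted, so uniqueness is enforced on server_url_hash instead. A plausible sketch of deriving that hash; SHA-256 (whose hex digest is 64 characters, fitting the String(64) column) is an assumption the diff does not confirm:

import hashlib

def compute_server_url_hash(server_url: str) -> str:
    # Stable digest of the plaintext URL, usable as a dedup key even though
    # the URL itself is persisted in encrypted form.
    return hashlib.sha256(server_url.encode("utf-8")).hexdigest()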
def load_user(self) -> Account | None:
@ -347,9 +347,9 @@ class ToolModelInvoke(Base):
"""
__tablename__ = "tool_model_invokes"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# who invoke this tool
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
@ -361,18 +361,18 @@ class ToolModelInvoke(Base):
# tool name
tool_name = mapped_column(String(128), nullable=False)
# invoke parameters
model_parameters = mapped_column(db.Text, nullable=False)
model_parameters = mapped_column(sa.Text, nullable=False)
# prompt messages
prompt_messages = mapped_column(db.Text, nullable=False)
prompt_messages = mapped_column(sa.Text, nullable=False)
# invoke response
model_response = mapped_column(db.Text, nullable=False)
model_response = mapped_column(sa.Text, nullable=False)
prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
total_price = mapped_column(db.Numeric(10, 7))
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -386,13 +386,13 @@ class ToolConversationVariables(Base):
__tablename__ = "tool_conversation_variables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
# add index for user_id and conversation_id
db.Index("user_id_idx", "user_id"),
db.Index("conversation_id_idx", "conversation_id"),
sa.Index("user_id_idx", "user_id"),
sa.Index("conversation_id_idx", "conversation_id"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# conversation user id
user_id = mapped_column(StringUUID, nullable=False)
# tenant id
@ -400,7 +400,7 @@ class ToolConversationVariables(Base):
# conversation id
conversation_id = mapped_column(StringUUID, nullable=False)
# variables pool
variables_str = mapped_column(db.Text, nullable=False)
variables_str = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -417,11 +417,11 @@ class ToolFile(Base):
__tablename__ = "tool_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
db.Index("tool_file_conversation_id_idx", "conversation_id"),
sa.PrimaryKeyConstraint("id", name="tool_file_pkey"),
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# conversation user id
user_id: Mapped[str] = mapped_column(StringUUID)
# tenant id
@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base):
__tablename__ = "tool_published_apps"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# id of the app
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who published this tool
description = mapped_column(db.Text, nullable=False)
description = mapped_column(sa.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = mapped_column(db.Text, nullable=False)
llm_description = mapped_column(sa.Text, nullable=False)
# query description; the query is seen as a parameter of the tool,
# and this field describes that parameter to the LLM
query_description = mapped_column(db.Text, nullable=False)
query_description = mapped_column(sa.Text, nullable=False)
# query name, the name of the query parameter
query_name = mapped_column(String(40), nullable=False)
# name of the tool provider
tool_name = mapped_column(String(40), nullable=False)
# author
author = mapped_column(String(40), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
@property
def description_i18n(self) -> I18nObject:

View File

@ -1,5 +1,6 @@
from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
@ -13,15 +14,15 @@ from .types import StringUUID
class SavedMessage(Base):
__tablename__ = "saved_messages"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
created_by_role = mapped_column(
String(255), nullable=False, server_default=db.text("'end_user'::character varying")
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
@ -34,15 +35,15 @@ class SavedMessage(Base):
class PinnedConversation(Base):
__tablename__ = "pinned_conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = mapped_column(
String(255), nullable=False, server_default=db.text("'end_user'::character varying")
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -6,6 +6,7 @@ from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
import sqlalchemy as sa
from flask_login import current_user
from sqlalchemy import DateTime, orm
@ -24,7 +25,6 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
@ -117,11 +117,11 @@ class Workflow(Base):
__tablename__ = "workflows"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_pkey"),
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
sa.PrimaryKeyConstraint("id", name="workflow_pkey"),
sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
type: Mapped[str] = mapped_column(String(255), nullable=False)
@ -140,10 +140,10 @@ class Workflow(Base):
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", db.Text, nullable=False, server_default="{}"
"environment_variables", sa.Text, nullable=False, server_default="{}"
)
_conversation_variables: Mapped[str] = mapped_column(
"conversation_variables", db.Text, nullable=False, server_default="{}"
"conversation_variables", sa.Text, nullable=False, server_default="{}"
)
VERSION_DRAFT = "draft"
@ -491,11 +491,11 @@ class WorkflowRun(Base):
__tablename__ = "workflow_runs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
@ -503,19 +503,19 @@ class WorkflowRun(Base):
type: Mapped[str] = mapped_column(String(255))
triggered_from: Mapped[str] = mapped_column(String(255))
version: Mapped[str] = mapped_column(String(255))
graph: Mapped[Optional[str]] = mapped_column(db.Text)
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
graph: Mapped[Optional[str]] = mapped_column(sa.Text)
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
error: Mapped[Optional[str]] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
@property
def created_by_account(self):
@ -704,25 +704,25 @@ class WorkflowNodeExecutionModel(Base):
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from: Mapped[str] = mapped_column(String(255))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
index: Mapped[int] = mapped_column(db.Integer)
index: Mapped[int] = mapped_column(sa.Integer)
predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
node_id: Mapped[str] = mapped_column(String(255))
node_type: Mapped[str] = mapped_column(String(255))
title: Mapped[str] = mapped_column(String(255))
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
process_data: Mapped[Optional[str]] = mapped_column(db.Text)
outputs: Mapped[Optional[str]] = mapped_column(db.Text)
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
process_data: Mapped[Optional[str]] = mapped_column(sa.Text)
outputs: Mapped[Optional[str]] = mapped_column(sa.Text)
status: Mapped[str] = mapped_column(String(255))
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
error: Mapped[Optional[str]] = mapped_column(sa.Text)
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
created_by_role: Mapped[str] = mapped_column(String(255))
created_by: Mapped[str] = mapped_column(StringUUID)
@ -834,11 +834,11 @@ class WorkflowAppLog(Base):
__tablename__ = "workflow_app_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@ -864,6 +864,19 @@ class WorkflowAppLog(Base):
created_by_role = CreatorUserRole(self.created_by_role)
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"workflow_id": self.workflow_id,
"workflow_run_id": self.workflow_run_id,
"created_from": self.created_from,
"created_by_role": self.created_by_role,
"created_by": self.created_by,
"created_at": self.created_at,
}
class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables"
@ -871,7 +884,7 @@ class ConversationVariable(Base):
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data: Mapped[str] = mapped_column(db.Text, nullable=False)
data: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
@ -933,7 +946,7 @@ class WorkflowDraftVariable(Base):
__allow_unmapped__ = True
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
created_at: Mapped[datetime] = mapped_column(
DateTime,

View File

@ -49,6 +49,8 @@ dependencies = [
"opentelemetry-instrumentation==0.48b0",
"opentelemetry-instrumentation-celery==0.48b0",
"opentelemetry-instrumentation-flask==0.48b0",
"opentelemetry-instrumentation-redis==0.48b0",
"opentelemetry-instrumentation-requests==0.48b0",
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
"opentelemetry-propagator-b3==1.27.0",
# opentelemetry-proto 1.28.0 depends on protobuf (>=5.0,<6.0),
@ -114,6 +116,7 @@ dev = [
"pytest-cov~=4.1.0",
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"testcontainers~=4.10.0",
"types-aiofiles~=24.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=5.5.0",

View File

@ -12,6 +12,7 @@ import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from packaging.version import parse as parse_version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -269,7 +270,7 @@ class AppDslService:
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
elif imported_version <= "0.1.5":
elif parse_version(imported_version) <= parse_version("0.1.5"):
if "workflow" in data:
graph = data.get("workflow", {}).get("graph", {})
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
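The switch to parse_version matters because the old code compared version strings lexicographically, which misorders multi-digit components. A standalone illustration:

from packaging.version import parse as parse_version

# "0.1.10" is a newer release than "0.1.5", but character-by-character
# string comparison claims it is not:
assert ("0.1.10" <= "0.1.5") is True                                  # lexicographic: wrong
assert (parse_version("0.1.10") <= parse_version("0.1.5")) is False   # semantic: right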

View File

@ -1,5 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Optional, Union
from openai._exceptions import RateLimitError
@ -15,6 +16,7 @@ from libs.helper import RateLimiter
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.billing_service import BillingService
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
@ -86,7 +88,8 @@ class AppGenerateService:
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, invoke_from)
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().generate(
@ -101,7 +104,8 @@ class AppGenerateService:
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, invoke_from)
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
@ -210,14 +214,27 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
"""
Get workflow
:param app_model: app model
:param invoke_from: invoke from
:param workflow_id: optional workflow id to pin a specific published version
:return:
"""
workflow_service = WorkflowService()
# If workflow_id is specified, get the specific workflow version
if workflow_id:
try:
workflow_uuid = uuid.UUID(workflow_id)
except ValueError:
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'.")
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
if not workflow:
raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}")
return workflow
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = workflow_service.get_draft_workflow(app_model=app_model)
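Parsing with uuid.UUID before the lookup lets malformed input fail fast as WorkflowIdFormatError instead of reaching the database. The same check in isolation, with made-up sample values:

import uuid

for candidate in ("not-a-uuid", "8b9f3d1e-2c4a-4f6b-9d0e-1a2b3c4d5e6f"):
    try:
        uuid.UUID(candidate)
    except ValueError:
        print(f"rejected before any DB lookup: {candidate!r}")
    else:
        print(f"well-formed, proceeds to get_published_workflow_by_id: {candidate!r}")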

View File

@ -159,9 +159,9 @@ class BillingService:
):
limiter_key = f"{account_id}:{tenant_id}"
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
from controllers.console.error import CompilanceRateLimitError
from controllers.console.error import ComplianceRateLimitError
raise CompilanceRateLimitError()
raise ComplianceRateLimitError()
json = {
"doc_name": doc_name,

View File

@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
@ -21,6 +33,85 @@ logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
Args:
session: Database session, the same as the one used in the process_tenant method
tenant_id: Tenant ID for logging purposes
batch_message_ids: List of message IDs to clean up
"""
if not batch_message_ids:
return
# Clean up each related table
related_tables = [
(MessageFeedback, "message_feedbacks"),
(MessageFile, "message_files"),
(MessageAnnotation, "message_annotations"),
(MessageChain, "message_chains"),
(MessageAgentThought, "message_agent_thoughts"),
(AppAnnotationHitHistory, "app_annotation_hit_histories"),
(SavedMessage, "saved_messages"),
]
for model, table_name in related_tables:
# Query records related to expired messages
records = (
session.query(model)
.filter(
model.message_id.in_(batch_message_ids), # type: ignore
)
.all()
)
if len(records) == 0:
continue
# Save records before deletion
record_ids = [record.id for record in records]
try:
record_data = []
for record in records:
try:
if hasattr(record, "to_dict"):
record_data.append(record.to_dict())
else:
# if the record doesn't have a to_dict method, convert it to a dict manually
record_dict = {}
for column in record.__table__.columns:
record_dict[column.name] = getattr(record, column.name)
record_data.append(record_dict)
except Exception:
logger.exception("Failed to transform %s record: %s", table_name, record.id)
continue
if record_data:
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(record_data),
).encode("utf-8"),
)
except Exception:
logger.exception("Failed to save %s records", table_name)
session.query(model).filter(
model.id.in_(record_ids), # type: ignore
).delete(synchronize_session=False)
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
f"{table_name} records for tenant {tenant_id}"
)
)
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs:
Message.id.in_(message_ids),
).delete(synchronize_session=False)
cls._clear_message_related_tables(session, tenant_id, message_ids)
session.commit()
click.echo(
@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs:
if len(workflow_runs) < batch:
break
while True:
with Session(db.engine).no_autoflush as session:
workflow_app_logs = (
session.query(WorkflowAppLog)
.filter(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
if len(workflow_app_logs) == 0:
break
# save workflow app logs
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(
[workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
),
).encode("utf-8"),
)
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
# delete workflow app logs
session.query(WorkflowAppLog).filter(
WorkflowAppLog.id.in_(workflow_app_log_ids),
).delete(synchronize_session=False)
session.commit()
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
f" workflow app logs for tenant {tenant_id}"
)
)
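Both loops in this file follow the same batched archive-then-purge pattern. Reduced to its skeleton (illustrative only; Model, session_factory, and archive are generic stand-ins, and a to_dict method is assumed on the model):

import json


def purge_expired(session_factory, Model, cutoff, batch, archive):
    while True:
        with session_factory() as session:
            rows = session.query(Model).filter(Model.created_at < cutoff).limit(batch).all()
            if not rows:
                break  # nothing older than the cutoff remains
            # archive first, so a crash between the two steps loses no data
            archive(json.dumps([row.to_dict() for row in rows]).encode("utf-8"))
            ids = [row.id for row in rows]
            session.query(Model).filter(Model.id.in_(ids)).delete(synchronize_session=False)
            session.commit()  # commit per batch so progress survives interruption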
@classmethod
def process(cls, days: int, batch: int, tenant_ids: list[str]):
"""

View File

@ -266,7 +266,7 @@ class DatasetService:
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
@staticmethod
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
@ -370,7 +370,7 @@ class DatasetService:
raise ValueError("External knowledge api id is required.")
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = datetime.datetime.utcnow()
dataset.updated_at = naive_utc_now()
db.session.add(dataset)
# Update external knowledge binding
@ -2372,7 +2372,7 @@ class SegmentService:
)
if not segments:
return
real_deal_segmment_ids = []
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
@ -2382,10 +2382,10 @@ class SegmentService:
segment.disabled_at = None
segment.disabled_by = None
db.session.add(segment)
real_deal_segmment_ids.append(segment.id)
real_deal_segment_ids.append(segment.id)
db.session.commit()
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
@ -2399,7 +2399,7 @@ class SegmentService:
)
if not segments:
return
real_deal_segmment_ids = []
real_deal_segment_ids = []
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
@ -2409,10 +2409,10 @@ class SegmentService:
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.add(segment)
real_deal_segmment_ids.append(segment.id)
real_deal_segment_ids.append(segment.id)
db.session.commit()
disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
else:
raise InvalidActionError()
@ -2670,7 +2670,7 @@ class SegmentService:
# check segment
segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
if not segment:

View File

@ -52,6 +52,16 @@ class EnterpriseService:
return data.get("result", False)
@classmethod
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
if not app_codes:
return {}
body = {"userId": user_id, "appCodes": app_codes}
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
if not data:
raise ValueError("No data found.")
return data.get("permissions", {})
@classmethod
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
if not app_id:

View File

@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception):
class IsDraftWorkflowError(Exception):
pass
class WorkflowNotFoundError(Exception):
pass
class WorkflowIdFormatError(Exception):
pass

View File

@ -79,7 +79,10 @@ class MetadataService:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
value = doc_metadata.pop(old_name, None)
doc_metadata[name] = value
document.doc_metadata = doc_metadata
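The guard above is not cosmetic: copy.deepcopy(None) simply returns None, so a document with no metadata would previously crash on the pop call that follows:

import copy

doc_metadata = copy.deepcopy(None)  # deepcopy of None is still None
# doc_metadata.pop("old_name", None) would then raise:
#   AttributeError: 'NoneType' object has no attribute 'pop'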
@ -109,7 +112,10 @@ class MetadataService:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(metadata.name, None)
document.doc_metadata = doc_metadata
db.session.add(document)
@ -137,7 +143,6 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = True
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
if documents:
@ -153,6 +158,7 @@ class MetadataService:
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
db.session.commit()
except Exception:
logging.exception("Enable built-in field failed")
@ -166,13 +172,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = False
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
@ -181,6 +189,7 @@ class MetadataService:
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
dataset.built_in_field_enabled = False
db.session.commit()
except Exception:
logging.exception("Disable built-in field failed")

View File

@ -2,6 +2,7 @@ import json
import logging
import click
import sqlalchemy as sa
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db
@ -38,7 +39,7 @@ class PluginDataMigration:
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
limit 1000"""
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql))
rs = conn.execute(sa.text(sql))
current_iter_count = 0
for i in rs:
@ -94,7 +95,7 @@ limit 1000"""
:provider_name
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
conn.execute(sa.text(sql), params)
click.echo(
click.style(
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
@ -148,7 +149,7 @@ limit 1000"""
params = {"last_id": last_id or ""}
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql), params)
rs = conn.execute(sa.text(sql), params)
current_iter_count = 0
batch_updates = []
@ -193,7 +194,7 @@ limit 1000"""
SET {provider_column_name} = :updated_value
WHERE id = :record_id
"""
conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
click.echo(
click.style(
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",

View File

@ -9,6 +9,7 @@ from typing import Any, Optional
from uuid import uuid4
import click
import sqlalchemy as sa
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session
@ -197,7 +198,7 @@ class PluginMigration:
"""
with Session(db.engine) as session:
rs = session.execute(
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
)
result = []
for row in rs:

View File

@ -422,7 +422,7 @@ class WorkflowDraftVariableService:
description=conv_var.description,
)
draft_conv_vars.append(draft_var)
_batch_upsert_draft_varaible(
_batch_upsert_draft_variable(
self._session,
draft_conv_vars,
policy=_UpsertPolicy.IGNORE,
@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum):
OVERWRITE = "overwrite"
def _batch_upsert_draft_varaible(
def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
@ -721,7 +721,7 @@ class DraftVariableSaver:
draft_vars = self._build_variables_from_start_mapping(outputs)
else:
draft_vars = self._build_variables_from_mapping(outputs)
_batch_upsert_draft_varaible(self._session, draft_vars)
_batch_upsert_draft_variable(self._session, draft_vars)
@staticmethod
def _should_variable_be_editable(node_id: str, name: str) -> bool:

View File

@ -129,7 +129,10 @@ class WorkflowService:
if not workflow:
return None
if workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
raise IsDraftWorkflowError(
f"Cannot use draft workflow version. Workflow ID: {workflow_id}. "
f"Please use a published workflow version or leave workflow_id empty."
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
@ -441,9 +444,9 @@ class WorkflowService:
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
Run draft workflow node
Run free workflow node
"""
# run draft workflow node
# run free workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(

View File

@ -3,6 +3,7 @@ import time
from collections.abc import Callable
import click
import sqlalchemy as sa
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(db.text(query_sql), params)
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
break

View File

@ -0,0 +1,328 @@
"""
TestContainers-based integration test configuration for Dify API.
This module provides containerized test infrastructure using TestContainers library
to spin up real database and service instances for integration testing. This approach
ensures tests run against actual service implementations rather than mocks, providing
more reliable and realistic test scenarios.
"""
import logging
import os
from collections.abc import Generator
from typing import Optional
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.postgres import PostgresContainer
from testcontainers.redis import RedisContainer
from app_factory import create_app
from models import db
# Configure logging for test containers
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
This class provides a centralized way to manage multiple containers
needed for comprehensive integration testing, including databases,
caches, and a sandboxed code-execution service.
"""
def __init__(self):
"""Initialize container management with default configurations."""
self.postgres: Optional[PostgresContainer] = None
self.redis: Optional[RedisContainer] = None
self.dify_sandbox: Optional[DockerContainer] = None
self._containers_started = False
logger.info("DifyTestContainers initialized - ready to manage test containers")
def start_containers_with_env(self) -> None:
"""
Start all required containers for integration testing.
This method initializes and starts the PostgreSQL, Redis, and Dify Sandbox
containers with appropriate configurations for Dify testing. Containers
are started in dependency order to ensure proper initialization.
"""
if self._containers_started:
logger.info("Containers already started - skipping container startup")
return
logger.info("Starting test containers for Dify integration tests...")
# Start PostgreSQL container for main application database
# PostgreSQL is used for storing user data, workflows, and application state
logger.info("Initializing PostgreSQL container...")
self.postgres = PostgresContainer(
image="postgres:16-alpine",
)
self.postgres.start()
db_host = self.postgres.get_container_host_ip()
db_port = self.postgres.get_exposed_port(5432)
os.environ["DB_HOST"] = db_host
os.environ["DB_PORT"] = str(db_port)
os.environ["DB_USERNAME"] = self.postgres.username
os.environ["DB_PASSWORD"] = self.postgres.password
os.environ["DB_DATABASE"] = self.postgres.dbname
logger.info(
"PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s",
db_host,
db_port,
self.postgres.username,
self.postgres.dbname,
)
# Wait for PostgreSQL to be ready
logger.info("Waiting for PostgreSQL to be ready to accept connections...")
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)
# Set up storage environment variables
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
logger.info("Initializing Redis container...")
self.redis = RedisContainer(image="redis:latest", port=6379)
self.redis.start()
redis_host = self.redis.get_container_host_ip()
redis_port = self.redis.get_exposed_port(6379)
os.environ["REDIS_HOST"] = redis_host
os.environ["REDIS_PORT"] = str(redis_port)
logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port)
# Wait for Redis to be ready
logger.info("Waiting for Redis to be ready to accept connections...")
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
logger.info("Redis container is ready and accepting connections")
# Start Dify Sandbox container for code execution environment
# Dify Sandbox provides a secure environment for executing user code
logger.info("Initializing Dify Sandbox container...")
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
self.dify_sandbox.with_exposed_ports(8194)
self.dify_sandbox.env = {
"API_KEY": "test_api_key",
}
self.dify_sandbox.start()
sandbox_host = self.dify_sandbox.get_container_host_ip()
sandbox_port = self.dify_sandbox.get_exposed_port(8194)
os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}"
os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key"
logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port)
# Wait for Dify Sandbox to be ready
logger.info("Waiting for Dify Sandbox to be ready to accept connections...")
wait_for_logs(self.dify_sandbox, "config init success", timeout=60)
logger.info("Dify Sandbox container is ready and accepting connections")
self._containers_started = True
logger.info("All test containers started successfully")
def stop_containers(self) -> None:
"""
Stop and clean up all test containers.
This method ensures proper cleanup of all containers to prevent
resource leaks and conflicts between test runs.
"""
if not self._containers_started:
logger.info("No containers to stop - containers were not started")
return
logger.info("Stopping and cleaning up test containers...")
containers = [self.redis, self.postgres, self.dify_sandbox]
for container in containers:
if container:
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")
# Global container manager instance
_container_manager = DifyTestContainers()
def _create_app_with_containers() -> Flask:
"""
Create Flask application configured to use test containers.
This function creates a Flask application instance that is configured
to connect to the test containers instead of the default development
or production databases.
Returns:
Flask: Configured Flask application for containerized testing
"""
logger.info("Creating Flask application with test container configuration...")
# Re-create the config after environment variables have been set
from configs import dify_config
# Force re-creation of config with new environment variables
dify_config.__dict__.clear()
dify_config.__init__()
# Create and configure the Flask application
logger.info("Initializing Flask application...")
app = create_app()
logger.info("Flask application created successfully")
# Initialize database schema
logger.info("Creating database schema...")
with app.app_context():
db.create_all()
logger.info("Database schema created successfully")
logger.info("Flask application configured and ready for testing")
return app
@pytest.fixture(scope="session")
def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]:
"""
Session-scoped fixture to manage test containers.
This fixture ensures containers are started once per test session
and properly cleaned up when all tests are complete. This approach
improves test performance by reusing containers across multiple tests.
Yields:
DifyTestContainers: Container manager instance
"""
logger.info("=== Starting test session container management ===")
_container_manager.start_containers_with_env()
logger.info("Test containers ready for session")
yield _container_manager
logger.info("=== Cleaning up test session containers ===")
_container_manager.stop_containers()
logger.info("Test session container cleanup completed")
@pytest.fixture(scope="session")
def flask_app_with_containers(set_up_containers_and_env) -> Flask:
"""
Session-scoped Flask application fixture using test containers.
This fixture provides a Flask application instance that is configured
to use the test containers for all database and service connections.
Args:
set_up_containers_and_env: Container manager fixture
Returns:
Flask: Configured Flask application
"""
logger.info("=== Creating session-scoped Flask application ===")
app = _create_app_with_containers()
logger.info("Session-scoped Flask application created successfully")
return app
@pytest.fixture
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
"""
Request context fixture for containerized Flask application.
This fixture provides a Flask request context for tests that need
to interact with the Flask application within a request scope.
Args:
flask_app_with_containers: Flask application fixture
Yields:
None: Request context is active during yield
"""
logger.debug("Creating Flask request context...")
with flask_app_with_containers.test_request_context():
logger.debug("Flask request context active")
yield
logger.debug("Flask request context closed")
@pytest.fixture
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
"""
Test client fixture for containerized Flask application.
This fixture provides a Flask test client that can be used to make
HTTP requests to the containerized application for integration testing.
Args:
flask_app_with_containers: Flask application fixture
Yields:
FlaskClient: Test client instance
"""
logger.debug("Creating Flask test client...")
with flask_app_with_containers.test_client() as client:
logger.debug("Flask test client ready")
yield client
logger.debug("Flask test client closed")
@pytest.fixture
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
"""
Database session fixture for containerized testing.
This fixture provides a SQLAlchemy database session that is connected
to the test PostgreSQL container, allowing tests to interact with
the database directly.
Args:
flask_app_with_containers: Flask application fixture
Yields:
Session: Database session instance
"""
logger.debug("Creating database session...")
with flask_app_with_containers.app_context():
session = db.session()
logger.debug("Database session created and ready")
try:
yield session
finally:
session.close()
logger.debug("Database session closed")

View File

@ -0,0 +1,371 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.file import File, FileTransferMethod, FileType
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from models import ToolFile, UploadFile
from models.enums import CreatorUserRole
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
class TestStorageKeyLoader(unittest.TestCase):
"""
Integration tests for StorageKeyLoader class.
Tests the batched loading of storage keys from the database for files
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
"""
def setUp(self):
"""Set up test data before each test method."""
self.session = db.session()
self.tenant_id = str(uuid4())
self.user_id = str(uuid4())
self.conversation_id = str(uuid4())
# Create test data that will be cleaned up after each test
self.test_upload_files = []
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(self.session, self.tenant_id)
def tearDown(self):
"""Clean up test data after each test method."""
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if storage_key is None:
storage_key = f"test_storage_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=storage_key,
name="test_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file.id = file_id
self.session.add(upload_file)
self.session.flush()
self.test_upload_files.append(upload_file)
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if file_key is None:
file_key = f"test_file_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
file_related_id = None
remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
return File(
id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="initial_key",
)
def test_load_storage_keys_local_file(self):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_remote_url(self):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_tool_file(self):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys
self.loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == tool_file.file_key
def test_load_storage_keys_mixed_methods(self):
"""Test batch loading with mixed transfer methods."""
# Create test data for different transfer methods
upload_file1 = self._create_upload_file()
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
files = [file1, file2, file3]
# Load storage keys
self.loader.load_storage_keys(files)
# Verify all storage keys were loaded correctly
assert file1._storage_key == upload_file1.key
assert file2._storage_key == upload_file2.key
assert file3._storage_key == tool_file.file_key
def test_load_storage_keys_empty_list(self):
"""Test with empty file list."""
# Should not raise any exceptions
self.loader.load_storage_keys([])
def test_load_storage_keys_tenant_mismatch(self):
"""Test tenant_id validation."""
# Create file with different tenant_id
upload_file = self._create_upload_file()
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
)
# Should raise ValueError for tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_missing_file_id(self):
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None
# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
assert str(context.value) == "file id should not be None."
def test_load_storage_keys_nonexistent_upload_file_records(self):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_nonexistent_tool_file_records(self):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
def test_load_storage_keys_batch_efficiency(self):
"""Test batched operations use efficient queries."""
# Create multiple files of different types
upload_files = [self._create_upload_file() for _ in range(3)]
tool_files = [self._create_tool_file() for _ in range(2)]
files = []
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
self.loader.load_storage_keys(files)
# Should make exactly 2 queries (one for upload_files, one for tool_files)
assert mock_scalars.call_count == 2
# Verify all storage keys were loaded correctly
for i, file in enumerate(files[:3]):
assert file._storage_key == upload_files[i].key
for i, file in enumerate(files[3:]):
assert file._storage_key == tool_files[i].file_key
def test_load_storage_keys_tenant_isolation(self):
"""Test that tenant isolation works correctly."""
# Create files for different tenants
other_tenant_id = str(uuid4())
# Create upload file for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type="local",
key="other_tenant_key",
name="other_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file_other.id = str(uuid4())
self.session.add(upload_file_other)
self.session.flush()
# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_other])
assert "invalid file, expected tenant_id" in str(context.value)
# Current tenant's file should still work
self.loader.load_storage_keys([file_current])
assert file_current._storage_key == upload_file_current.key
def test_load_storage_keys_mixed_tenant_batch(self):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
)
# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_current, file_other])
assert "invalid file, expected tenant_id" in str(context.value)
def test_load_storage_keys_duplicate_file_ids(self):
"""Test handling of duplicate file IDs in the batch."""
# Create upload file
upload_file = self._create_upload_file()
# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])
# Both files should have the same storage key
assert file1._storage_key == upload_file.key
assert file2._storage_key == upload_file.key
def test_load_storage_keys_session_isolation(self):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create loader with different session (same underlying connection)
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(other_session, self.tenant_id)
with pytest.raises(ValueError):
other_loader.load_storage_keys([file])

File diff suppressed because it is too large

View File

@ -0,0 +1,739 @@
import pytest
from faker import Faker
from core.variables.segments import StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from models import App, Workflow
from models.enums import DraftVariableType
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import (
UpdateNotSupportedError,
WorkflowDraftVariableService,
)
class TestWorkflowDraftVariableService:
"""
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
This test class covers all major functionality of the WorkflowDraftVariableService:
- CRUD operations for workflow draft variables (Create, Read, Update, Delete)
- Variable listing and filtering by type (conversation, system, node)
- Variable updates and resets with proper validation
- Variable deletion operations at different scopes
- Special functionality like prefill and conversation ID retrieval
- Error handling for various edge cases and invalid operations
All tests use the testcontainers infrastructure to ensure proper database isolation
and realistic testing environment with actual database interactions.
"""
@pytest.fixture
def mock_external_service_dependencies(self):
"""
Mock setup for external service dependencies.
WorkflowDraftVariableService doesn't have external dependencies that need mocking,
so this fixture returns an empty dictionary to maintain consistency with other test classes.
This ensures the test structure remains consistent across different service test files.
"""
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
return {}
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
"""
Helper method to create a test app with realistic data for testing.
This method creates a complete App instance with all required fields populated
using Faker for generating realistic test data. The app is configured for
workflow mode to support workflow draft variable testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies (unused in this service)
fake: Faker instance for generating test data, creates new instance if not provided
Returns:
App: Created test app instance with all required fields populated
"""
fake = fake or Faker()
app = App()
app.id = fake.uuid4()
app.tenant_id = fake.uuid4()
app.name = fake.company()
app.description = fake.text()
app.mode = "workflow"
app.icon_type = "emoji"
app.icon = "🤖"
app.icon_background = "#FFEAD5"
app.enable_site = True
app.enable_api = True
app.created_by = fake.uuid4()
app.updated_by = app.created_by
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
return app
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
"""
Helper method to create a test workflow associated with an app.
This method creates a Workflow instance using the proper factory method
to ensure all required fields are set correctly. The workflow is configured
as a draft version with basic graph structure for testing workflow variables.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
app: The app to associate the workflow with
fake: Faker instance for generating test data, creates new instance if not provided
Returns:
Workflow: Created test workflow instance with proper configuration
"""
fake = fake or Faker()
workflow = Workflow.new(
tenant_id=app.tenant_id,
app_id=app.id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features="{}",
created_by=app.created_by,
environment_variables=[],
conversation_variables=[],
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
return workflow
def _create_test_variable(
self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None
):
"""
Helper method to create a test workflow draft variable with proper configuration.
This method creates different types of variables (conversation, system, node) using
the appropriate factory methods to ensure proper initialization. Each variable type
has specific requirements and this method handles the creation logic for all types.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
app_id: ID of the app to associate the variable with
node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID)
name: Name of the variable for identification
value: StringSegment value for the variable content
variable_type: Type of variable ("conversation", "system", "node") determining creation method
fake: Faker instance for generating test data, creates new instance if not provided
Returns:
WorkflowDraftVariable: Created test variable instance with proper type configuration
"""
fake = fake or Faker()
if variable_type == "conversation":
# Create conversation variable using the appropriate factory method
variable = WorkflowDraftVariable.new_conversation_variable(
app_id=app_id,
name=name,
value=value,
description=fake.text(max_nb_chars=20),
)
elif variable_type == "system":
# Create system variable with editable flag and execution context
variable = WorkflowDraftVariable.new_sys_variable(
app_id=app_id,
name=name,
value=value,
node_execution_id=fake.uuid4(),
editable=True,
)
else: # node variable
# Create node variable with visibility and editability settings
variable = WorkflowDraftVariable.new_node_variable(
app_id=app_id,
node_id=node_id,
name=name,
value=value,
node_execution_id=fake.uuid4(),
visible=True,
editable=True,
)
from extensions.ext_database import db
db.session.add(variable)
db.session.commit()
return variable
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting a single variable by ID successfully.
This test verifies that the service can retrieve a specific variable
by its ID and that the returned variable contains the correct data.
It ensures the basic CRUD read operation works correctly for workflow draft variables.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
test_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_variable = service.get_variable(variable.id)
assert retrieved_variable is not None
assert retrieved_variable.id == variable.id
assert retrieved_variable.name == "test_var"
assert retrieved_variable.app_id == app.id
assert retrieved_variable.get_value().value == test_value.value
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting a variable that doesn't exist.
This test verifies that the service returns None when trying to
retrieve a variable with a non-existent ID. This ensures proper
handling of missing data scenarios.
"""
fake = Faker()
non_existent_id = fake.uuid4()
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_variable = service.get_variable(non_existent_id)
assert retrieved_variable is None
def test_get_draft_variables_by_selectors_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting variables by selectors successfully.
This test verifies that the service can retrieve multiple variables
using selector pairs (node_id, variable_name) and returns the correct
variables for each selector. This is useful for bulk variable retrieval
operations in workflow execution contexts.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
var1_value = StringSegment(value=fake.word())
var2_value = StringSegment(value=fake.word())
var3_value = StringSegment(value=fake.word())
var1 = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake
)
var2 = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
)
var3 = self._create_test_variable(
db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake
)
selectors = [
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
[CONVERSATION_VARIABLE_NODE_ID, "var2"],
["test_node_1", "var3"],
]
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
assert len(retrieved_variables) == 3
var_names = [var.name for var in retrieved_variables]
assert "var1" in var_names
assert "var2" in var_names
assert "var3" in var_names
for var in retrieved_variables:
if var.name == "var1":
assert var.get_value().value == var1_value.value
elif var.name == "var2":
assert var.get_value().value == var2_value.value
elif var.name == "var3":
assert var.get_value().value == var3_value.value
def test_list_variables_without_values_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test listing variables without values successfully with pagination.
This test verifies that the service can list variables with pagination
and that the returned variables don't include their values (for performance).
This is important for scenarios where only variable metadata is needed
without loading the actual content.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
for i in range(5):
test_value = StringSegment(value=fake.numerify("value##"))
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_variables_without_values(app.id, page=1, limit=3)
assert result.total == 5
assert len(result.variables) == 3
assert result.variables[0].created_at >= result.variables[1].created_at
assert result.variables[1].created_at >= result.variables[2].created_at
for var in result.variables:
assert var.name is not None
assert var.app_id == app.id
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test listing variables for a specific node successfully.
This test verifies that the service can filter and return only
variables associated with a specific node ID. This is crucial for
workflow execution where variables need to be scoped to specific nodes.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
node_id = fake.word()
var1_value = StringSegment(value=fake.word())
var2_value = StringSegment(value=fake.word())
var3_value = StringSegment(value=fake.word())
self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake)
self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake)
self._create_test_variable(
db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_node_variables(app.id, node_id)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == node_id
assert var.app_id == app.id
var_names = [var.name for var in result.variables]
assert "var1" in var_names
assert "var2" in var_names
assert "var3" not in var_names
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test listing conversation variables successfully.
This test verifies that the service can filter and return only
conversation variables, excluding system and node variables.
Conversation variables are user-facing variables that can be
modified during conversation flows.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
conv_var1_value = StringSegment(value=fake.word())
conv_var2_value = StringSegment(value=fake.word())
conv_var1 = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake
)
conv_var2 = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake
)
sys_var_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_conversation_variables(app.id)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
assert var.app_id == app.id
assert var.get_variable_type() == DraftVariableType.CONVERSATION
var_names = [var.name for var in result.variables]
assert "conv_var1" in var_names
assert "conv_var2" in var_names
assert "sys_var" not in var_names
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test updating a variable's name and value successfully.
This test verifies that the service can update both the name and value
of an editable variable and that the changes are persisted correctly.
It also checks that the last_edited_at timestamp is updated appropriately.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
original_value = StringSegment(value=fake.word())
new_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"original_name",
original_value,
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
updated_variable = service.update_variable(variable, name="new_name", value=new_value)
assert updated_variable.name == "new_name"
assert updated_variable.get_value().value == new_value.value
assert updated_variable.last_edited_at is not None
from extensions.ext_database import db
db.session.refresh(variable)
assert variable.name == "new_name"
assert variable.get_value().value == new_value.value
assert variable.last_edited_at is not None
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test that updating a non-editable variable raises an exception.
This test verifies that the service properly prevents updates to
variables that are not marked as editable. This is important for
maintaining data integrity and preventing unauthorized modifications
to system-controlled variables.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
original_value = StringSegment(value=fake.word())
new_value = StringSegment(value=fake.word())
variable = WorkflowDraftVariable.new_sys_variable(
app_id=app.id,
name=fake.word(), # This is typically not editable
value=original_value,
node_execution_id=fake.uuid4(),
editable=False, # Set as non-editable
)
from extensions.ext_database import db
db.session.add(variable)
db.session.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
with pytest.raises(UpdateNotSupportedError) as exc_info:
service.update_variable(variable, name="new_name", value=new_value)
assert "variable not support updating" in str(exc_info.value)
assert variable.id in str(exc_info.value)
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test resetting conversation variable successfully.
This test verifies that the service can reset a conversation variable
to its default value and clear the last_edited_at timestamp.
This functionality is useful for reverting user modifications
back to the original workflow configuration.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
from core.variables.variables import StringVariable
conv_var = StringVariable(
id=fake.uuid4(),
name="test_conv_var",
value="default_value",
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
)
workflow.conversation_variables = [conv_var]
from extensions.ext_database import db
db.session.commit()
modified_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers,
app.id,
CONVERSATION_VARIABLE_NODE_ID,
"test_conv_var",
modified_value,
fake=fake,
)
variable.last_edited_at = fake.date_time()
db.session.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
reset_variable = service.reset_variable(workflow, variable)
assert reset_variable is not None
assert reset_variable.get_value().value == "default_value"
assert reset_variable.last_edited_at is None
db.session.refresh(variable)
assert variable.get_value().value == "default_value"
assert variable.last_edited_at is None
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test deleting a single variable successfully.
This test verifies that the service can delete a specific variable
and that it's properly removed from the database. It ensures that
the deletion operation is atomic and complete.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
test_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
)
from extensions.ext_database import db
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_variable(variable)
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test deleting all variables for a workflow successfully.
This test verifies that the service can delete all variables
associated with a specific app/workflow. This is useful for
cleanup operations when workflows are deleted or reset.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
for i in range(3):
test_value = StringSegment(value=fake.numerify("value##"))
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
)
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
other_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake
)
from extensions.ext_database import db
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
assert len(app_variables) == 3
assert len(other_app_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_workflow_variables(app.id)
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
assert len(app_variables_after) == 0
assert len(other_app_variables_after) == 1
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test deleting all variables for a specific node successfully.
This test verifies that the service can delete all variables
associated with a specific node while preserving variables
for other nodes and conversation variables. This is important
for node-specific cleanup operations in workflow management.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
node_id = fake.word()
for i in range(2):
test_value = StringSegment(value=fake.numerify("node_value##"))
self._create_test_variable(
db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake
)
other_node_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake
)
conv_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake
)
from extensions.ext_database import db
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
other_node_variables = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
)
conv_variables = (
db.session.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
assert len(target_node_variables) == 2
assert len(other_node_variables) == 1
assert len(conv_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_node_variables(app.id, node_id)
target_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
)
other_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
)
conv_variables_after = (
db.session.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
assert len(target_node_variables_after) == 0
assert len(other_node_variables_after) == 1
assert len(conv_variables_after) == 1
def test_prefill_conversation_variable_default_values_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test prefill conversation variable default values successfully.
This test verifies that the service can automatically create
conversation variables with default values based on the workflow
configuration when none exist. This is important for initializing
workflow variables with proper defaults from the workflow definition.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
from core.variables.variables import StringVariable
conv_var1 = StringVariable(
id=fake.uuid4(),
name="conv_var1",
value="default_value1",
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"],
)
conv_var2 = StringVariable(
id=fake.uuid4(),
name="conv_var2",
value="default_value2",
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
)
workflow.conversation_variables = [conv_var1, conv_var2]
from extensions.ext_database import db
db.session.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
service.prefill_conversation_variable_default_values(workflow)
draft_variables = (
db.session.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
assert len(draft_variables) == 2
var_names = [var.name for var in draft_variables]
assert "conv_var1" in var_names
assert "conv_var2" in var_names
for var in draft_variables:
assert var.app_id == app.id
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
assert var.editable is True
assert var.get_variable_type() == DraftVariableType.CONVERSATION
def test_get_conversation_id_from_draft_variable_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting conversation ID from draft variable successfully.
This test verifies that the service can extract the conversation ID
from a system variable named "conversation_id". This is important
for maintaining conversation context across workflow executions.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
conversation_id = fake.uuid4()
conv_id_value = StringSegment(value=conversation_id)
self._create_test_variable(
db_session_with_containers,
app.id,
SYSTEM_VARIABLE_NODE_ID,
"conversation_id",
conv_id_value,
"system",
fake=fake,
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
assert retrieved_conv_id == conversation_id
def test_get_conversation_id_from_draft_variable_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test getting conversation ID when it doesn't exist.
This test verifies that the service returns None when no
conversation_id variable exists for the app. This ensures
proper handling of missing conversation context scenarios.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
assert retrieved_conv_id is None
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test listing system variables successfully.
This test verifies that the service can filter and return only
system variables, excluding conversation and node variables.
System variables are internal variables used by the workflow
engine for maintaining state and context.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
sys_var1_value = StringSegment(value=fake.word())
sys_var2_value = StringSegment(value=fake.word())
sys_var1 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake
)
sys_var2 = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake
)
conv_var_value = StringSegment(value=fake.word())
self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
result = service.list_system_variables(app.id)
assert len(result.variables) == 2
for var in result.variables:
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
assert var.app_id == app.id
assert var.get_variable_type() == DraftVariableType.SYS
var_names = [var.name for var in result.variables]
assert "sys_var1" in var_names
assert "sys_var2" in var_names
assert "conv_var" not in var_names
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting variables by name successfully for different types.
This test verifies that the service can retrieve variables by name
for different variable types (conversation, system, node). This
functionality is important for variable lookup operations during
workflow execution and user interactions.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
test_value = StringSegment(value=fake.word())
conv_var = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
)
sys_var = self._create_test_variable(
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake
)
node_var = self._create_test_variable(
db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake
)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
assert retrieved_conv_var is not None
assert retrieved_conv_var.name == "test_conv_var"
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
assert retrieved_sys_var is not None
assert retrieved_sys_var.name == "test_sys_var"
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
assert retrieved_node_var is not None
assert retrieved_node_var.name == "test_node_var"
assert retrieved_node_var.node_id == "test_node"
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test getting variables by name when they don't exist.
This test verifies that the service returns None when trying to
retrieve variables by name that don't exist. This ensures proper
handling of missing variable scenarios for all variable types.
"""
fake = Faker()
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
service = WorkflowDraftVariableService(db_session_with_containers)
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
assert retrieved_conv_var is None
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
assert retrieved_sys_var is None
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
assert retrieved_node_var is None

View File

@ -0,0 +1,11 @@
import pytest
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"

View File

@ -0,0 +1,47 @@
from textwrap import dedent
from .test_utils import CodeExecutorTestMixin
class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
"""Test class for JavaScript code executor functionality."""
def test_javascript_plain(self, flask_app_with_containers):
"""Test basic JavaScript code execution with console.log output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = 'console.log("Hello World")'
result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result_message == "Hello World\n"
def test_javascript_json(self, flask_app_with_containers):
"""Test JavaScript code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = dedent("""
obj = {'Hello': 'World'}
console.log(JSON.stringify(obj))
""")
result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result == '{"Hello":"World"}\n'
def test_javascript_with_code_template(self, flask_app_with_containers):
"""Test JavaScript workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
JavascriptCodeProvider, _ = self.javascript_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JAVASCRIPT,
code=JavascriptCodeProvider.get_default_code(),
inputs={"arg1": "Hello", "arg2": "World"},
)
assert result == {"result": "HelloWorld"}
def test_javascript_get_runner_script(self, flask_app_with_containers):
"""Test JavaScript template transformer runner script generation"""
_, NodeJsTemplateTransformer = self.javascript_imports
runner_script = NodeJsTemplateTransformer.get_runner_script()
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2

View File

@ -0,0 +1,42 @@
import base64
from .test_utils import CodeExecutorTestMixin
class TestJinja2CodeExecutor(CodeExecutorTestMixin):
"""Test class for Jinja2 code executor functionality."""
def test_jinja2(self, flask_app_with_containers):
"""Test basic Jinja2 template execution with variable substitution"""
CodeExecutor, CodeLanguage = self.code_executor_imports
_, Jinja2TemplateTransformer = self.jinja2_imports
template = "Hello {{template}}"
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
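# Assemble the runner script by hand: the transformer's placeholders receive
# the raw template and the base64-encoded inputs (presumably the same wiring
# execute_workflow_code_template performs internally).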
code = (
Jinja2TemplateTransformer.get_runner_script()
.replace(Jinja2TemplateTransformer._code_placeholder, template)
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
)
result = CodeExecutor.execute_code(
language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
)
assert result == "<<RESULT>>Hello World<<RESULT>>\n"
def test_jinja2_with_code_template(self, flask_app_with_containers):
"""Test Jinja2 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"}
)
assert result == {"result": "Hello World"}
def test_jinja2_get_runner_script(self, flask_app_with_containers):
"""Test Jinja2 template transformer runner script generation"""
_, Jinja2TemplateTransformer = self.jinja2_imports
runner_script = Jinja2TemplateTransformer.get_runner_script()
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2

View File

@ -0,0 +1,47 @@
from textwrap import dedent
from .test_utils import CodeExecutorTestMixin
class TestPython3CodeExecutor(CodeExecutorTestMixin):
"""Test class for Python3 code executor functionality."""
def test_python3_plain(self, flask_app_with_containers):
"""Test basic Python3 code execution with print output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = 'print("Hello World")'
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == "Hello World\n"
def test_python3_json(self, flask_app_with_containers):
"""Test Python3 code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
code = dedent("""
import json
print(json.dumps({'Hello': 'World'}))
""")
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == '{"Hello": "World"}\n'
def test_python3_with_code_template(self, flask_app_with_containers):
"""Test Python3 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
Python3CodeProvider, _ = self.python3_imports
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.PYTHON3,
code=Python3CodeProvider.get_default_code(),
inputs={"arg1": "Hello", "arg2": "World"},
)
assert result == {"result": "HelloWorld"}
def test_python3_get_runner_script(self, flask_app_with_containers):
"""Test Python3 template transformer runner script generation"""
_, Python3TemplateTransformer = self.python3_imports
runner_script = Python3TemplateTransformer.get_runner_script()
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2

View File

@ -0,0 +1,115 @@
"""
Test utilities for code executor integration tests.
This module provides lazy import functions to avoid module loading issues
that occur when modules are imported before the flask_app_with_containers fixture
has set up the proper environment variables and configuration.
"""
import importlib
def force_reload_code_executor():
"""
Force reload the code_executor module to reinitialize code_execution_endpoint_url.
This function should be called after setting up environment variables
to ensure the code_execution_endpoint_url is initialized with the correct value.
"""
try:
import core.helper.code_executor.code_executor
importlib.reload(core.helper.code_executor.code_executor)
except Exception as e:
# Log the error but don't fail the test
print(f"Warning: Failed to reload code_executor module: {e}")
def get_code_executor_imports():
"""
Lazy import function for core CodeExecutor classes.
Returns:
tuple: (CodeExecutor, CodeLanguage) classes
"""
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
return CodeExecutor, CodeLanguage
def get_javascript_imports():
"""
Lazy import function for JavaScript-specific modules.
Returns:
tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes
"""
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
return JavascriptCodeProvider, NodeJsTemplateTransformer
def get_python3_imports():
"""
Lazy import function for Python3-specific modules.
Returns:
tuple: (Python3CodeProvider, Python3TemplateTransformer) classes
"""
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
return Python3CodeProvider, Python3TemplateTransformer
def get_jinja2_imports():
"""
Lazy import function for Jinja2-specific modules.
Returns:
tuple: (None, Jinja2TemplateTransformer) classes
"""
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
return None, Jinja2TemplateTransformer
class CodeExecutorTestMixin:
"""
Mixin class providing lazy import methods for code executor tests.
This mixin helps avoid module loading issues by deferring imports
until after the flask_app_with_containers fixture has set up the environment.
"""
def setup_method(self):
"""
Setup method called before each test method.
Force reload the code_executor module to ensure fresh initialization.
"""
force_reload_code_executor()
@property
def code_executor_imports(self):
"""Property to get CodeExecutor and CodeLanguage classes."""
return get_code_executor_imports()
@property
def javascript_imports(self):
"""Property to get JavaScript-specific classes."""
return get_javascript_imports()
@property
def python3_imports(self):
"""Property to get Python3-specific classes."""
return get_python3_imports()
@property
def jinja2_imports(self):
"""Property to get Jinja2-specific classes."""
return get_jinja2_imports()
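# A minimal usage sketch (illustrative only; follows the same pattern as the
# executor test classes above and assumes the flask_app_with_containers fixture):
#
#     class TestMyExecutor(CodeExecutorTestMixin):
#         def test_plain(self, flask_app_with_containers):
#             CodeExecutor, CodeLanguage = self.code_executor_imports
#             result = CodeExecutor.execute_code(
#                 language=CodeLanguage.PYTHON3, preload="", code='print("hi")'
#             )
#             assert result == "hi\n"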

View File

@ -1,4 +1,4 @@
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
class TestSegmentTypeIsArrayType:
@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType:
value is tested for the is_array_type method.
"""
# Arrange
all_segment_types = set(SegmentType)
expected_array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
@ -58,3 +57,27 @@ class TestSegmentTypeIsArrayType:
for seg_type in enum_values:
is_array = seg_type.is_array_type()
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
class TestSegmentTypeIsValidArrayValidation:
"""
Test SegmentType.is_valid with array types using different validation strategies.
"""
def test_array_validation_all_success(self):
value = ["hello", "world", "foo"]
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
def test_array_validation_all_fail(self):
value = ["hello", 123, "world"]
# Should return False, since 123 is not a string
assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
def test_array_validation_first(self):
value = ["hello", 123, None]
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
def test_array_validation_none(self):
value = [1, 2, 3]
# With ArrayValidation.NONE, per-element validation is skipped entirely
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)

View File

@ -0,0 +1,168 @@
import datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
class TestClearFreePlanTenantExpiredLogs:
"""Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""
@pytest.fixture
def mock_session(self):
"""Create a mock database session."""
session = Mock(spec=Session)
session.query.return_value.filter.return_value.all.return_value = []
session.query.return_value.filter.return_value.delete.return_value = 0
return session
@pytest.fixture
def mock_storage(self):
"""Create a mock storage object."""
storage = Mock()
storage.save.return_value = None
return storage
@pytest.fixture
def sample_message_ids(self):
"""Sample message IDs for testing."""
return ["msg-1", "msg-2", "msg-3"]
@pytest.fixture
def sample_records(self):
"""Sample records for testing."""
records = []
for i in range(3):
record = Mock()
record.id = f"record-{i}"
record.to_dict.return_value = {
"id": f"record-{i}",
"message_id": f"msg-{i}",
"created_at": datetime.datetime.now().isoformat(),
}
records.append(record)
return records
def test_clear_message_related_tables_empty_message_ids(self, mock_session):
"""Test that method returns early when message_ids is empty."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
# Should not call any database operations
mock_session.query.assert_not_called()
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
"""Test when no related records are found."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.filter.return_value.all.return_value = []
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call query for each related table but find no records
assert mock_session.query.call_count > 0
mock_storage.save.assert_not_called()
def test_clear_message_related_tables_with_records_and_to_dict(
self, mock_session, sample_message_ids, sample_records
):
"""Test when records are found and have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call to_dict on each record (called once per table, so 7 times total)
for record in sample_records:
assert record.to_dict.call_count == 7
# Should save backup data
assert mock_storage.save.call_count > 0
def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
"""Test when records are found but don't have to_dict method."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
# Create records without to_dict method
records = []
for i in range(2):
record = Mock()
mock_table = Mock()
mock_id_column = Mock()
mock_id_column.name = "id"
mock_message_id_column = Mock()
mock_message_id_column.name = "message_id"
mock_table.columns = [mock_id_column, mock_message_id_column]
record.__table__ = mock_table
record.id = f"record-{i}"
record.message_id = f"msg-{i}"
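# Deleting the attribute makes the Mock raise AttributeError for to_dict,
# so the code under test can fall back to column-based serialization.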
del record.to_dict
records.append(record)
# Mock records for first table only, empty for others
mock_session.query.return_value.filter.return_value.all.side_effect = [
records,
[],
[],
[],
[],
[],
[],
]
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should save backup data even without to_dict
assert mock_storage.save.call_count > 0
def test_clear_message_related_tables_storage_error_continues(
self, mock_session, sample_message_ids, sample_records
):
"""Test that method continues even when storage.save fails."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_storage.save.side_effect = Exception("Storage error")
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if backup fails
assert mock_session.query.return_value.filter.return_value.delete.called
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
"""Test that method continues even when record serialization fails."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
record = Mock()
record.id = "record-1"
record.to_dict.side_effect = Exception("Serialization error")
mock_session.query.return_value.filter.return_value.all.return_value = [record]
# Should not raise exception
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should still delete records even if serialization fails
assert mock_session.query.return_value.filter.return_value.delete.called
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
"""Test that deletion is called for found records."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# Should call delete for each table that has records
assert mock_session.query.return_value.filter.return_value.delete.called
def test_clear_message_related_tables_logging_output(
self, mock_session, sample_message_ids, sample_records, capsys
):
"""Test that logging output is generated."""
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
# No assertions on captured output; completing without an exception is the
# behavior under test here.

View File

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
@ -1265,6 +1265,8 @@ dependencies = [
{ name = "opentelemetry-instrumentation" },
{ name = "opentelemetry-instrumentation-celery" },
{ name = "opentelemetry-instrumentation-flask" },
{ name = "opentelemetry-instrumentation-redis" },
{ name = "opentelemetry-instrumentation-requests" },
{ name = "opentelemetry-instrumentation-sqlalchemy" },
{ name = "opentelemetry-propagator-b3" },
{ name = "opentelemetry-proto" },
@ -1318,6 +1320,7 @@ dev = [
{ name = "pytest-mock" },
{ name = "ruff" },
{ name = "scipy-stubs" },
{ name = "testcontainers" },
{ name = "types-aiofiles" },
{ name = "types-beautifulsoup4" },
{ name = "types-cachetools" },
@ -1447,6 +1450,8 @@ requires-dist = [
{ name = "opentelemetry-instrumentation", specifier = "==0.48b0" },
{ name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" },
{ name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" },
{ name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" },
{ name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" },
{ name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" },
{ name = "opentelemetry-propagator-b3", specifier = "==1.27.0" },
{ name = "opentelemetry-proto", specifier = "==1.27.0" },
@ -1500,6 +1505,7 @@ dev = [
{ name = "pytest-mock", specifier = "~=3.14.0" },
{ name = "ruff", specifier = "~=0.12.3" },
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "testcontainers", specifier = "~=4.10.0" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
{ name = "types-cachetools", specifier = "~=5.5.0" },
@ -1600,6 +1606,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]
[[package]]
name = "docker"
version = "7.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "requests" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
]
[[package]]
name = "docstring-parser"
version = "0.16"
@ -3654,6 +3674,36 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" },
]
[[package]]
name = "opentelemetry-instrumentation-redis"
version = "0.48b0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "opentelemetry-api" },
{ name = "opentelemetry-instrumentation" },
{ name = "opentelemetry-semantic-conventions" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" },
]
[[package]]
name = "opentelemetry-instrumentation-requests"
version = "0.48b0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "opentelemetry-api" },
{ name = "opentelemetry-instrumentation" },
{ name = "opentelemetry-semantic-conventions" },
{ name = "opentelemetry-util-http" },
]
sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" },
]
[[package]]
name = "opentelemetry-instrumentation-sqlalchemy"
version = "0.48b0"
@ -5468,6 +5518,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
]
[[package]]
name = "testcontainers"
version = "4.10.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "docker" },
{ name = "python-dotenv" },
{ name = "typing-extensions" },
{ name = "urllib3" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" },
]
[[package]]
name = "tidb-vector"
version = "0.0.9"

View File

@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh
# Unit tests
dev/pytest/pytest_unit_tests.sh
# TestContainers tests
dev/pytest/pytest_testcontainers.sh

View File

@ -0,0 +1,7 @@
#!/bin/bash
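# Runs the TestContainers-backed integration suite (assumes a running Docker
# daemon, which testcontainers requires to start its service containers):
#   dev/pytest/pytest_testcontainers.sh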
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/test_containers_integration_tests

View File

@ -538,7 +538,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.5.0-beta
image: milvusdb/milvus:v2.5.15
profiles:
- milvus
command: [ 'milvus', 'run', 'standalone' ]

View File

@ -1087,7 +1087,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.5.0-beta
image: milvusdb/milvus:v2.5.15
profiles:
- milvus
command: [ 'milvus', 'run', 'standalone' ]

View File

@ -265,7 +265,6 @@ export default translation
fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content)
const allEnKeys = await getKeysFromLanguage('en-US')
const allZhKeys = await getKeysFromLanguage('zh-Hans')
// Test file filtering logic
const targetFile = 'components'
@ -563,4 +562,201 @@ export default translation
expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means zh-Hans has 2 extra keys
})
})
describe('Auto-remove multiline key-value pairs', () => {
// Helper function to simulate removeExtraKeysFromFile logic
function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string {
const lines = content.split('\n')
const linesToRemove: number[] = []
for (const keyToRemove of keysToRemove) {
let targetLineIndex = -1
const linesToRemoveForKey: number[] = []
// Find the key line (simplified for single-level keys in test)
for (let i = 0; i < lines.length; i++) {
const line = lines[i]
const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`)
if (keyPattern.test(line)) {
targetLineIndex = i
break
}
}
if (targetLineIndex !== -1) {
linesToRemoveForKey.push(targetLineIndex)
// Check if this is a multiline key-value pair
const keyLine = lines[targetLineIndex]
const trimmedKeyLine = keyLine.trim()
// If the key line ends with ':' and carries no inline value, it is likely a multiline entry
if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) {
// Find the value lines that belong to this key
let currentLine = targetLineIndex + 1
let foundValue = false
while (currentLine < lines.length) {
const line = lines[currentLine]
const trimmed = line.trim()
// Skip empty lines
if (trimmed === '') {
currentLine++
continue
}
// Check if this line starts a new key (indicates end of current value)
if (trimmed.match(/^\w+\s*:/))
break
// Check if this line is part of the value
if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) {
linesToRemoveForKey.push(currentLine)
foundValue = true
// Check if this line ends the value (closing quote, with or without a trailing comma)
if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,')
|| trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`'))
&& !trimmed.startsWith('//'))
break
}
else {
break
}
currentLine++
}
}
linesToRemove.push(...linesToRemoveForKey)
}
}
// Remove duplicates and sort in descending order so splicing doesn't shift pending indices
const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a)
for (const lineIndex of uniqueLinesToRemove)
lines.splice(lineIndex, 1)
return lines.join('\n')
}
it('should remove single-line key-value pairs correctly', () => {
const content = `const translation = {
keepThis: 'This should stay',
removeThis: 'This should be removed',
alsoKeep: 'This should also stay',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeThis'])
expect(result).toContain('keepThis: \'This should stay\'')
expect(result).toContain('alsoKeep: \'This should also stay\'')
expect(result).not.toContain('removeThis: \'This should be removed\'')
})
it('should remove multiline key-value pairs completely', () => {
const content = `const translation = {
keepThis: 'This should stay',
removeMultiline:
'This is a multiline value that should be removed completely',
alsoKeep: 'This should also stay',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeMultiline'])
expect(result).toContain('keepThis: \'This should stay\'')
expect(result).toContain('alsoKeep: \'This should also stay\'')
expect(result).not.toContain('removeMultiline:')
expect(result).not.toContain('This is a multiline value that should be removed completely')
})
it('should handle mixed single-line and multiline removals', () => {
const content = `const translation = {
keepThis: 'Keep this',
removeSingle: 'Remove this single line',
removeMultiline:
'Remove this multiline value',
anotherMultiline:
'Another multiline that spans multiple lines',
keepAnother: 'Keep this too',
}
export default translation`
const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline'])
expect(result).toContain('keepThis: \'Keep this\'')
expect(result).toContain('keepAnother: \'Keep this too\'')
expect(result).not.toContain('removeSingle:')
expect(result).not.toContain('removeMultiline:')
expect(result).not.toContain('anotherMultiline:')
expect(result).not.toContain('Remove this single line')
expect(result).not.toContain('Remove this multiline value')
expect(result).not.toContain('Another multiline that spans multiple lines')
})
it('should properly detect multiline vs single-line patterns', () => {
const multilineContent = `const translation = {
singleLine: 'This is single line',
multilineKey:
'This is multiline',
keyWithColon: 'Value with: colon inside',
objectKey: {
nested: 'value'
},
}
export default translation`
// Test that a single line with a colon inside its value is not treated as multiline
const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon'])
expect(result1).not.toContain('keyWithColon:')
expect(result1).not.toContain('Value with: colon inside')
// Test that true multiline is handled correctly
const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey'])
expect(result2).not.toContain('multilineKey:')
expect(result2).not.toContain('This is multiline')
// Test that object key removal works (note: this is a simplified test)
// In real scenario, object removal would be more complex
const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey'])
expect(result3).not.toContain('objectKey: {')
// Note: Our simplified test function doesn't handle nested object removal perfectly
// This is acceptable as it's testing the main multiline string removal functionality
})
it('should handle real-world Polish translation structure', () => {
const polishContent = `const translation = {
createApp: 'UTWÓRZ APLIKACJĘ',
newApp: {
captionAppType: 'Jaki typ aplikacji chcesz stworzyć?',
chatbotDescription:
'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.',
agentDescription:
'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.',
basic: 'Podstawowy',
},
}
export default translation`
const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription'])
expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'')
expect(result).toContain('basic: \'Podstawowy\'')
expect(result).not.toContain('captionAppType:')
expect(result).not.toContain('chatbotDescription:')
expect(result).not.toContain('agentDescription:')
expect(result).not.toContain('Jaki typ aplikacji')
expect(result).not.toContain('Zbuduj aplikację opartą na czacie')
expect(result).not.toContain('Zbuduj inteligentnego agenta')
})
})
})

View File

@ -0,0 +1,305 @@
/**
* Document Detail Navigation Fix Verification Test
*
* This test specifically validates that the backToPrev function in the document detail
* component correctly preserves pagination and filter states.
*/
import { fireEvent, render, screen } from '@testing-library/react'
import { useRouter } from 'next/navigation'
import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document'
// Mock Next.js router
const mockPush = jest.fn()
jest.mock('next/navigation', () => ({
useRouter: jest.fn(() => ({
push: mockPush,
})),
}))
// Mock the document service hooks
jest.mock('@/service/knowledge/use-document', () => ({
useDocumentDetail: jest.fn(),
useDocumentMetadata: jest.fn(),
useInvalidDocumentList: jest.fn(() => jest.fn()),
}))
// Mock other dependencies
jest.mock('@/context/dataset-detail', () => ({
useDatasetDetailContext: jest.fn(() => [null]),
}))
jest.mock('@/service/use-base', () => ({
useInvalid: jest.fn(() => jest.fn()),
}))
jest.mock('@/service/knowledge/use-segment', () => ({
useSegmentListKey: jest.fn(),
useChildSegmentListKey: jest.fn(),
}))
// Create a minimal version of the DocumentDetail component that includes our fix
const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => {
const router = useRouter()
// This is the FIXED implementation from detail/index.tsx
const backToPrev = () => {
// Preserve pagination and filter states when navigating back
const searchParams = new URLSearchParams(window.location.search)
const queryString = searchParams.toString()
const separator = queryString ? '?' : ''
const backPath = `/datasets/${datasetId}/documents${separator}${queryString}`
router.push(backPath)
}
return (
<div data-testid="document-detail-fixed">
<button data-testid="back-button-fixed" onClick={backToPrev}>
Back to Documents
</button>
<div data-testid="document-info">
Dataset: {datasetId}, Document: {documentId}
</div>
</div>
)
}
describe('Document Detail Navigation Fix Verification', () => {
beforeEach(() => {
jest.clearAllMocks()
// Mock successful API responses
;(useDocumentDetail as jest.Mock).mockReturnValue({
data: {
id: 'doc-123',
name: 'Test Document',
display_status: 'available',
enabled: true,
archived: false,
},
error: null,
})
;(useDocumentMetadata as jest.Mock).mockReturnValue({
data: null,
error: null,
})
})
describe('Query Parameter Preservation', () => {
test('preserves pagination state (page 3, limit 25)', () => {
// Simulate user coming from page 3 with 25 items per page
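// (defineProperty is used because jsdom exposes window.location as read-only,
// so it cannot be replaced by plain assignment in tests)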
Object.defineProperty(window, 'location', {
value: {
search: '?page=3&limit=25',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
// User clicks back button
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should preserve the pagination state
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25')
console.log('✅ Pagination state preserved: page=3&limit=25')
})
test('preserves search keyword and filters', () => {
// Simulate user with search and filters applied
Object.defineProperty(window, 'location', {
value: {
search: '?page=2&limit=10&keyword=API%20documentation&status=active',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should preserve all query parameters
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active')
console.log('✅ Search and filters preserved')
})
test('handles complex query parameters with special characters', () => {
// Test with complex query string including encoded characters
Object.defineProperty(window, 'location', {
value: {
search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// URLSearchParams will normalize the encoding, but preserve all parameters
const expectedCall = mockPush.mock.calls[0][0]
expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/)
expect(expectedCall).toMatch(/page=1/)
expect(expectedCall).toMatch(/limit=50/)
expect(expectedCall).toMatch(/keyword=test/)
expect(expectedCall).toMatch(/sort=name/)
expect(expectedCall).toMatch(/order=desc/)
console.log('✅ Complex query parameters handled:', expectedCall)
})
test('handles empty query parameters gracefully', () => {
// No query parameters in URL
Object.defineProperty(window, 'location', {
value: {
search: '',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should navigate to clean documents URL
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents')
console.log('✅ Empty parameters handled gracefully')
})
})
describe('Different Dataset IDs', () => {
test('works with different dataset identifiers', () => {
Object.defineProperty(window, 'location', {
value: {
search: '?page=5&limit=10',
},
writable: true,
})
// Test with different dataset ID format
render(<DocumentDetailWithFix datasetId="ds-prod-2024-001" documentId="doc-456" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10')
console.log('✅ Works with different dataset ID formats')
})
})
describe('Real User Scenarios', () => {
test('scenario: user searches, goes to page 3, views document, clicks back', () => {
// User searched for "API" and navigated to page 3
Object.defineProperty(window, 'location', {
value: {
search: '?keyword=API&page=3&limit=10',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="main-dataset" documentId="api-doc-123" />)
// User decides to go back to continue browsing
fireEvent.click(screen.getByTestId('back-button-fixed'))
// Should return to page 3 of API search results
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10')
console.log('✅ Real user scenario: search + pagination preserved')
})
test('scenario: user applies multiple filters, goes to document, returns', () => {
// User has applied multiple filters and is on page 2
Object.defineProperty(window, 'location', {
value: {
search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="filtered-dataset" documentId="filtered-doc" />)
fireEvent.click(screen.getByTestId('back-button-fixed'))
// All filters should be preserved
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc')
console.log('✅ Complex filtering scenario preserved')
})
})
describe('Error Handling and Edge Cases', () => {
test('handles malformed query parameters gracefully', () => {
// Test with potentially problematic query string
Object.defineProperty(window, 'location', {
value: {
search: '?page=invalid&limit=&keyword=test&=emptykey&malformed',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
// Should not throw errors
expect(() => {
fireEvent.click(screen.getByTestId('back-button-fixed'))
}).not.toThrow()
// Should still attempt navigation (URLSearchParams will clean up the parameters)
expect(mockPush).toHaveBeenCalled()
const navigationPath = mockPush.mock.calls[0][0]
expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/)
console.log('✅ Malformed parameters handled gracefully:', navigationPath)
})
test('handles very long query strings', () => {
// Test with a very long query string
const longKeyword = 'a'.repeat(1000)
Object.defineProperty(window, 'location', {
value: {
search: `?page=1&keyword=${longKeyword}`,
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
expect(() => {
fireEvent.click(screen.getByTestId('back-button-fixed'))
}).not.toThrow()
expect(mockPush).toHaveBeenCalled()
console.log('✅ Long query strings handled')
})
})
describe('Performance Verification', () => {
test('navigation function executes quickly', () => {
Object.defineProperty(window, 'location', {
value: {
search: '?page=1&limit=10&keyword=test',
},
writable: true,
})
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
const startTime = performance.now()
fireEvent.click(screen.getByTestId('back-button-fixed'))
const endTime = performance.now()
const executionTime = endTime - startTime
// Should execute in less than 10ms
expect(executionTime).toBeLessThan(10)
console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`)
})
})
})

View File

@ -0,0 +1,83 @@
/**
* Document List Sorting Tests
*/
describe('Document List Sorting', () => {
const mockDocuments = [
{ id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 },
{ id: '2', name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 },
{ id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 },
]
const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => {
return [...docs].sort((a, b) => {
let aValue: any
let bValue: any
switch (field) {
case 'name':
aValue = a.name?.toLowerCase() || ''
bValue = b.name?.toLowerCase() || ''
break
case 'word_count':
aValue = a.word_count || 0
bValue = b.word_count || 0
break
case 'hit_count':
aValue = a.hit_count || 0
bValue = b.hit_count || 0
break
case 'created_at':
aValue = a.created_at
bValue = b.created_at
break
default:
return 0
}
if (field === 'name') {
const result = aValue.localeCompare(bValue)
return order === 'asc' ? result : -result
}
else {
const result = aValue - bValue
return order === 'asc' ? result : -result
}
})
}
test('sorts by name descending (default for UI consistency)', () => {
const sorted = sortDocuments(mockDocuments, 'name', 'desc')
expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt'])
})
test('sorts by name ascending (after toggle)', () => {
const sorted = sortDocuments(mockDocuments, 'name', 'asc')
expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx'])
})
test('sorts by word_count descending', () => {
const sorted = sortDocuments(mockDocuments, 'word_count', 'desc')
expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200])
})
test('sorts by hit_count descending', () => {
const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc')
expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5])
})
test('sorts by created_at descending (newest first)', () => {
const sorted = sortDocuments(mockDocuments, 'created_at', 'desc')
expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400])
})
test('handles empty values correctly', () => {
const docsWithEmpty = [
{ id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 },
{ id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 },
]
const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc')
expect(sorted.map(doc => doc.word_count)).toEqual([100, 0])
})
})

View File

@ -0,0 +1,290 @@
/**
* Navigation Utilities Test
*
* Tests for the navigation utility functions to ensure they handle
* query parameter preservation correctly across different scenarios.
*/
import {
createBackNavigation,
createNavigationPath,
createNavigationPathWithParams,
datasetNavigation,
extractQueryParams,
mergeQueryParams,
} from '@/utils/navigation'
// Mock router for testing
const mockPush = jest.fn()
const mockRouter = { push: mockPush }
describe('Navigation Utilities', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('createNavigationPath', () => {
test('preserves query parameters by default', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test')
})
test('returns clean path when preserveParams is false', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents', false)
expect(path).toBe('/datasets/123/documents')
})
test('handles empty query parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '' },
writable: true,
})
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents')
})
test('handles errors gracefully', () => {
// Mock window.location to throw an error
Object.defineProperty(window, 'location', {
get: () => {
throw new Error('Location access denied')
},
configurable: true,
})
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const path = createNavigationPath('/datasets/123/documents')
expect(path).toBe('/datasets/123/documents')
expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error))
consoleSpy.mockRestore()
})
})
describe('createBackNavigation', () => {
test('creates function that navigates with preserved params', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25')
})
test('creates function that navigates without params when specified', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false)
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents')
})
})
describe('extractQueryParams', () => {
test('extracts specified parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test&other=value' },
writable: true,
})
const params = extractQueryParams(['page', 'limit', 'keyword'])
expect(params).toEqual({
page: '3',
limit: '10',
keyword: 'test',
})
})
test('handles missing parameters', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3' },
writable: true,
})
const params = extractQueryParams(['page', 'limit', 'missing'])
expect(params).toEqual({
page: '3',
})
})
test('handles errors gracefully', () => {
Object.defineProperty(window, 'location', {
get: () => {
throw new Error('Location access denied')
},
configurable: true,
})
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const params = extractQueryParams(['page', 'limit'])
expect(params).toEqual({})
expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error))
consoleSpy.mockRestore()
})
})
describe('createNavigationPathWithParams', () => {
test('creates path with specified parameters', () => {
const path = createNavigationPathWithParams('/datasets/123/documents', {
page: 1,
limit: 25,
keyword: 'search term',
})
expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term')
})
test('filters out empty values', () => {
const path = createNavigationPathWithParams('/datasets/123/documents', {
page: 1,
limit: '',
keyword: 'test',
empty: null,
undefined,
})
expect(path).toBe('/datasets/123/documents?page=1&keyword=test')
})
test('handles errors gracefully', () => {
// Mock URLSearchParams to throw an error
const originalURLSearchParams = globalThis.URLSearchParams
globalThis.URLSearchParams = jest.fn(() => {
throw new Error('URLSearchParams error')
}) as any
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 })
expect(path).toBe('/datasets/123/documents')
expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error))
consoleSpy.mockRestore()
globalThis.URLSearchParams = originalURLSearchParams
})
})
describe('mergeQueryParams', () => {
test('merges new params with existing ones', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const merged = mergeQueryParams({ keyword: 'test', page: '1' })
const result = merged.toString()
expect(result).toContain('page=1') // overridden
expect(result).toContain('limit=10') // preserved
expect(result).toContain('keyword=test') // added
})
test('removes parameters when value is null', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10&keyword=test' },
writable: true,
})
const merged = mergeQueryParams({ keyword: null, filter: 'active' })
const result = merged.toString()
expect(result).toContain('page=3')
expect(result).toContain('limit=10')
expect(result).not.toContain('keyword')
expect(result).toContain('filter=active')
})
test('creates fresh params when preserveExisting is false', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=3&limit=10' },
writable: true,
})
const merged = mergeQueryParams({ keyword: 'test' }, false)
const result = merged.toString()
expect(result).toBe('keyword=test')
})
})
describe('datasetNavigation', () => {
test('backToDocuments creates correct navigation function', () => {
Object.defineProperty(window, 'location', {
value: { search: '?page=2&limit=25' },
writable: true,
})
const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25')
})
test('toDocumentDetail creates correct navigation function', () => {
const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456')
detailFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456')
})
test('toDocumentSettings creates correct navigation function', () => {
const settingsFn = datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456')
settingsFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings')
})
})
describe('Real-world Integration Scenarios', () => {
test('complete user workflow: list -> detail -> back', () => {
// User starts on page 3 with search
Object.defineProperty(window, 'location', {
value: { search: '?page=3&keyword=API&limit=25' },
writable: true,
})
// Create back navigation function (as would be done in detail component)
const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset')
// User clicks back
backToDocuments()
// Should return to exact same list state
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25')
})
test('user applies filters then views document', () => {
// Complex filter state
Object.defineProperty(window, 'location', {
value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' },
writable: true,
})
const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents')
backFn()
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc')
})
})
})
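
For reference, a minimal implementation of the utility these tests pin down could look like the sketch below. It is inferred purely from the expectations above (preserve window.location.search, fall back to the bare path, warn on errors); the actual implementation in the repo may differ in details.

// Sketch of createNavigationPath, reconstructed from the test expectations.
export const createNavigationPath = (basePath: string, preserveParams = true): string => {
  if (!preserveParams)
    return basePath
  try {
    // window.location.search is '' when there are no query parameters.
    const search = window.location.search
    return search ? `${basePath}${search}` : basePath
  }
  catch (error) {
    // Matches the warning asserted in the error-handling test above.
    console.warn('Failed to preserve query parameters:', error)
    return basePath
  }
}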

View File

@ -0,0 +1,396 @@
/**
* Unified Tags Editing - Pure Logic Tests
*
* This test file validates the core business logic and state management
 * behaviors introduced in the seven most recent commits, without requiring complex mocks.
*/
describe('Unified Tags Editing - Pure Logic Tests', () => {
describe('Tag State Management Logic', () => {
it('should detect when tag values have changed', () => {
const currentValue = ['tag1', 'tag2']
const newSelectedTagIDs = ['tag1', 'tag3']
// This is the valueNotChanged logic from TagSelector component
const valueNotChanged
= currentValue.length === newSelectedTagIDs.length
&& currentValue.every(v => newSelectedTagIDs.includes(v))
&& newSelectedTagIDs.every(v => currentValue.includes(v))
expect(valueNotChanged).toBe(false)
})
it('should correctly identify unchanged tag values', () => {
const currentValue = ['tag1', 'tag2']
const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order
const valueNotChanged
= currentValue.length === newSelectedTagIDs.length
&& currentValue.every(v => newSelectedTagIDs.includes(v))
&& newSelectedTagIDs.every(v => currentValue.includes(v))
expect(valueNotChanged).toBe(true)
})
it('should calculate correct tag operations for binding/unbinding', () => {
const currentValue = ['tag1', 'tag2']
const selectedTagIDs = ['tag2', 'tag3']
// This is the handleValueChange logic from TagSelector
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual(['tag3'])
expect(removeTagIDs).toEqual(['tag1'])
})
it('should handle empty tag arrays correctly', () => {
const currentValue: string[] = []
const selectedTagIDs = ['tag1']
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual(['tag1'])
expect(removeTagIDs).toEqual([])
expect(currentValue.length).toBe(0) // Verify empty array usage
})
it('should handle removing all tags', () => {
const currentValue = ['tag1', 'tag2']
const selectedTagIDs: string[] = []
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
expect(addTagIDs).toEqual([])
expect(removeTagIDs).toEqual(['tag1', 'tag2'])
expect(selectedTagIDs.length).toBe(0) // Verify empty array usage
})
})
describe('Fallback Logic (from layout-main.tsx)', () => {
it('should trigger fallback when tags are missing or empty', () => {
const appDetailWithoutTags = { tags: [] }
const appDetailWithTags = { tags: [{ id: 'tag1' }] }
const appDetailWithUndefinedTags = { tags: undefined as any }
// This simulates the condition in layout-main.tsx
const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0
expect(shouldFallback1).toBe(true) // Empty array should trigger fallback
expect(shouldFallback2).toBe(false) // Has tags, no fallback needed
expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback
})
it('should preserve tags when fallback succeeds', () => {
const originalAppDetail = { tags: [] as any[] }
const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
// This simulates the successful fallback in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
expect(originalAppDetail.tags.length).toBe(1)
})
it('should continue with empty tags when fallback fails', () => {
const originalAppDetail: { tags: any[] } = { tags: [] }
const fallbackResult: { tags?: any[] } | null = null
// This simulates fallback failure in layout-main.tsx
if (fallbackResult?.tags)
originalAppDetail.tags = fallbackResult.tags
expect(originalAppDetail.tags).toEqual([])
})
})
describe('TagSelector Auto-initialization Logic', () => {
it('should trigger getTagList when tagList is empty', () => {
const tagList: any[] = []
let getTagListCalled = false
const getTagList = () => {
getTagListCalled = true
}
// This simulates the useEffect in TagSelector
if (tagList.length === 0)
getTagList()
expect(getTagListCalled).toBe(true)
})
it('should not trigger getTagList when tagList has items', () => {
const tagList = [{ id: 'tag1', name: 'existing-tag' }]
let getTagListCalled = false
const getTagList = () => {
getTagListCalled = true
}
// This simulates the useEffect in TagSelector
if (tagList.length === 0)
getTagList()
expect(getTagListCalled).toBe(false)
})
})
describe('State Initialization Patterns', () => {
it('should maintain AppCard tag state pattern', () => {
const app = { tags: [{ id: 'tag1', name: 'test' }] }
// Original AppCard pattern: useState(app.tags)
const initialTags = app.tags
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(1)
expect(initialTags).toBe(app.tags) // Reference equality for AppCard
})
it('should maintain AppInfo tag state pattern', () => {
const appDetail = { tags: [{ id: 'tag1', name: 'test' }] }
// New AppInfo pattern: useState(appDetail?.tags || [])
const initialTags = appDetail?.tags || []
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(1)
})
it('should handle undefined appDetail gracefully in AppInfo', () => {
const appDetail = undefined
// AppInfo pattern with undefined appDetail
const initialTags = (appDetail as any)?.tags || []
expect(Array.isArray(initialTags)).toBe(true)
expect(initialTags.length).toBe(0)
})
})
describe('CSS Class and Layout Logic', () => {
it('should apply correct minimum width condition', () => {
const minWidth = 'true'
// This tests the minWidth logic in TagSelector
const shouldApplyMinWidth = minWidth && '!min-w-80'
expect(shouldApplyMinWidth).toBe('!min-w-80')
})
it('should not apply minimum width when not specified', () => {
const minWidth = undefined
const shouldApplyMinWidth = minWidth && '!min-w-80'
expect(shouldApplyMinWidth).toBeFalsy()
})
it('should handle overflow layout classes correctly', () => {
// This tests the layout pattern from AppCard and new AppInfo
const overflowLayoutClasses = {
container: 'flex w-0 grow items-center',
inner: 'w-full',
truncate: 'truncate',
}
expect(overflowLayoutClasses.container).toContain('w-0 grow')
expect(overflowLayoutClasses.inner).toContain('w-full')
expect(overflowLayoutClasses.truncate).toBe('truncate')
})
})
describe('fetchAppWithTags Service Logic', () => {
it('should correctly find app by ID from app list', () => {
const appList = [
{ id: 'app1', name: 'App 1', tags: [] },
{ id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] },
{ id: 'app3', name: 'App 3', tags: [] },
]
const targetAppId = 'test-app-id'
// This simulates the logic in fetchAppWithTags
const foundApp = appList.find(app => app.id === targetAppId)
expect(foundApp).toBeDefined()
expect(foundApp?.id).toBe('test-app-id')
expect(foundApp?.tags.length).toBe(1)
})
it('should return null when app not found', () => {
const appList = [
{ id: 'app1', name: 'App 1' },
{ id: 'app2', name: 'App 2' },
]
const targetAppId = 'nonexistent-app'
const foundApp = appList.find(app => app.id === targetAppId) || null
expect(foundApp).toBeNull()
})
it('should handle empty app list', () => {
const appList: any[] = []
const targetAppId = 'any-app'
const foundApp = appList.find(app => app.id === targetAppId) || null
expect(foundApp).toBeNull()
expect(appList.length).toBe(0) // Verify empty array usage
})
})
describe('Data Structure Validation', () => {
it('should maintain consistent tag data structure', () => {
const tag = {
id: 'tag1',
name: 'test-tag',
type: 'app',
binding_count: 1,
}
expect(tag).toHaveProperty('id')
expect(tag).toHaveProperty('name')
expect(tag).toHaveProperty('type')
expect(tag).toHaveProperty('binding_count')
expect(tag.type).toBe('app')
expect(typeof tag.binding_count).toBe('number')
})
it('should handle tag arrays correctly', () => {
const tags = [
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
{ id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 },
]
expect(Array.isArray(tags)).toBe(true)
expect(tags.length).toBe(2)
expect(tags.every(tag => tag.type === 'app')).toBe(true)
})
it('should validate app data structure with tags', () => {
const app = {
id: 'test-app',
name: 'Test App',
tags: [
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
],
}
expect(app).toHaveProperty('id')
expect(app).toHaveProperty('name')
expect(app).toHaveProperty('tags')
expect(Array.isArray(app.tags)).toBe(true)
expect(app.tags.length).toBe(1)
})
})
describe('Performance and Edge Cases', () => {
it('should handle large tag arrays efficiently', () => {
const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`)
const selectedTags = ['tag1', 'tag50', 'tag99']
// Performance test: filtering should be efficient
const startTime = Date.now()
const addTags = selectedTags.filter(tag => !largeTags.includes(tag))
const removeTags = largeTags.filter(tag => !selectedTags.includes(tag))
const endTime = Date.now()
expect(endTime - startTime).toBeLessThan(10) // Should be very fast
expect(addTags.length).toBe(0) // All selected tags exist
expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove
})
it('should handle malformed tag data gracefully', () => {
const mixedData = [
{ id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 },
{ id: 'invalid1' }, // Missing required properties
null,
undefined,
{ id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 },
]
// Filter out invalid entries
const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } =>
tag != null
&& typeof tag === 'object'
&& 'id' in tag
&& 'name' in tag
&& 'type' in tag
&& 'binding_count' in tag
&& typeof tag.binding_count === 'number',
)
expect(validTags.length).toBe(2)
expect(validTags.every(tag => tag.id && tag.name)).toBe(true)
})
it('should handle concurrent tag operations correctly', () => {
const operations = [
{ type: 'add', tagIds: ['tag1', 'tag2'] },
{ type: 'remove', tagIds: ['tag3'] },
{ type: 'add', tagIds: ['tag4'] },
]
// Simulate processing operations
const results = operations.map(op => ({
...op,
processed: true,
timestamp: Date.now(),
}))
expect(results.length).toBe(3)
expect(results.every(result => result.processed)).toBe(true)
})
})
describe('Backward Compatibility Verification', () => {
it('should not break existing AppCard behavior', () => {
// Verify AppCard continues to work with original patterns
const originalAppCardLogic = {
initializeTags: (app: any) => app.tags,
updateTags: (_currentTags: any[], newTags: any[]) => newTags,
shouldRefresh: true,
}
const app = { tags: [{ id: 'tag1', name: 'original' }] }
const initializedTags = originalAppCardLogic.initializeTags(app)
expect(initializedTags).toBe(app.tags)
expect(originalAppCardLogic.shouldRefresh).toBe(true)
})
it('should ensure AppInfo follows AppCard patterns', () => {
// Verify AppInfo uses compatible state management
const appCardPattern = (app: any) => app.tags
const appInfoPattern = (appDetail: any) => appDetail?.tags || []
const appWithTags = { tags: [{ id: 'tag1' }] }
const appWithoutTags = { tags: [] }
const undefinedApp = undefined
expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags))
expect(appInfoPattern(appWithoutTags)).toEqual([])
expect(appInfoPattern(undefinedApp)).toEqual([])
})
it('should maintain consistent API parameters', () => {
// Verify service layer maintains expected parameters
const fetchAppListParams = {
url: '/apps',
params: { page: 1, limit: 100 },
}
const tagApiParams = {
bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }),
unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }),
}
expect(fetchAppListParams.url).toBe('/apps')
expect(fetchAppListParams.params.limit).toBe(100)
const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app')
expect(bindResult.tagIDs).toEqual(['tag1'])
expect(bindResult.type).toBe('app')
})
})
})
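
The add/remove calculation exercised throughout these tests can be factored into a tiny pure helper. The sketch below is illustrative only (computeTagOperations is a hypothetical name, not an export of the codebase):

// Given the tags currently bound and the tags now selected, compute which
// bindings to add and which to remove.
const computeTagOperations = (currentValue: string[], selectedTagIDs: string[]) => ({
  addTagIDs: selectedTagIDs.filter(v => !currentValue.includes(v)),
  removeTagIDs: currentValue.filter(v => !selectedTagIDs.includes(v)),
})

// e.g. computeTagOperations(['tag1', 'tag2'], ['tag2', 'tag3'])
// => { addTagIDs: ['tag3'], removeTagIDs: ['tag1'] }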

View File

@ -0,0 +1,212 @@
/**
* XSS Fix Verification Test
*
* This test verifies that the XSS vulnerability in check-code pages has been
* properly fixed by replacing dangerouslySetInnerHTML with safe React rendering.
*/
import React from 'react'
import { cleanup, render } from '@testing-library/react'
import '@testing-library/jest-dom'
// Mock i18next with the new safe translation structure
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
if (key === 'login.checkCode.tipsPrefix')
return 'We send a verification code to '
return key
},
}),
}))
// Mock Next.js useSearchParams
jest.mock('next/navigation', () => ({
useSearchParams: () => ({
get: (key: string) => {
if (key === 'email')
return 'test@example.com<script>alert("XSS")</script>'
return null
},
}),
}))
// Fixed CheckCode component implementation (current secure version)
const SecureCheckCodeComponent = ({ email }: { email: string }) => {
const { t } = require('react-i18next').useTranslation()
return (
<div>
<h1>Check Code</h1>
<p>
<span>
{t('login.checkCode.tipsPrefix')}
<strong>{email}</strong>
</span>
</p>
</div>
)
}
// Vulnerable implementation for comparison (what we fixed)
const VulnerableCheckCodeComponent = ({ email }: { email: string }) => {
const mockTranslation = (key: string, params?: any) => {
if (key === 'login.checkCode.tips' && params?.email)
return `We send a verification code to <strong>${params.email}</strong>`
return key
}
return (
<div>
<h1>Check Code</h1>
<p>
<span dangerouslySetInnerHTML={{ __html: mockTranslation('login.checkCode.tips', { email }) }}></span>
</p>
</div>
)
}
describe('XSS Fix Verification - Check Code Pages Security', () => {
afterEach(() => {
cleanup()
})
const maliciousEmail = 'test@example.com<script>alert("XSS")</script>'
it('should securely render email with HTML characters as text (FIXED VERSION)', () => {
console.log('\n🔒 Security Fix Verification Report')
console.log('===================================')
const { container } = render(<SecureCheckCodeComponent email={maliciousEmail} />)
const spanElement = container.querySelector('span')
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log('\n✅ Fixed Implementation Results:')
console.log('- Email rendered in strong tag:', strongElement?.textContent)
console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('<script>'))
console.log('- Script elements created:', scriptElements.length)
console.log('- Full text content:', spanElement?.textContent)
// Verify secure behavior
expect(strongElement?.textContent).toBe(maliciousEmail) // Email rendered as text
expect(strongElement?.textContent).toContain('<script>') // HTML visible as text
expect(scriptElements).toHaveLength(0) // No script elements created
expect(spanElement?.textContent).toBe(`We send a verification code to ${maliciousEmail}`)
console.log('\n🎯 Security Status: SECURE - HTML automatically escaped by React')
})
it('should demonstrate the vulnerability that was fixed (VULNERABLE VERSION)', () => {
const { container } = render(<VulnerableCheckCodeComponent email={maliciousEmail} />)
const spanElement = container.querySelector('span')
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log('\n⚠ Previous Vulnerable Implementation:')
console.log('- HTML content:', spanElement?.innerHTML)
console.log('- Strong element text:', strongElement?.textContent)
console.log('- Script elements created:', scriptElements.length)
console.log('- Script content:', scriptElements[0]?.textContent)
// Verify vulnerability exists in old implementation
expect(scriptElements).toHaveLength(1) // Script element was created
expect(scriptElements[0]?.textContent).toBe('alert("XSS")') // Contains malicious code
expect(spanElement?.innerHTML).toContain('<script>') // Raw HTML in DOM
console.log('\n❌ Security Status: VULNERABLE - dangerouslySetInnerHTML creates script elements')
})
it('should verify all affected components use the secure pattern', () => {
console.log('\n📋 Component Security Audit')
console.log('============================')
// Test multiple malicious inputs
const testCases = [
'user@test.com<img src=x onerror=alert(1)>',
'test@evil.com<div onclick="alert(2)">click</div>',
'admin@site.com<script>document.cookie="stolen"</script>',
'normal@email.com',
]
testCases.forEach((testEmail, index) => {
const { container } = render(<SecureCheckCodeComponent email={testEmail} />)
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
const imgElements = container.querySelectorAll('img')
const divElements = container.querySelectorAll('div:not([data-testid])')
console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`)
console.log(` - Script elements: ${scriptElements.length}`)
console.log(` - Img elements: ${imgElements.length}`)
console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div
console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`)
// All should be safe
expect(scriptElements).toHaveLength(0)
expect(imgElements).toHaveLength(0)
expect(strongElement?.textContent).toBe(testEmail)
})
console.log('\n✅ All test cases passed - secure rendering confirmed')
})
it('should validate the translation structure is secure', () => {
console.log('\n🔍 Translation Security Analysis')
console.log('=================================')
const { t } = require('react-i18next').useTranslation()
const prefix = t('login.checkCode.tipsPrefix')
console.log('- Translation key used: login.checkCode.tipsPrefix')
console.log('- Translation value:', prefix)
console.log('- Contains HTML tags:', prefix.includes('<'))
console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>'))
// Verify translation is plain text
expect(prefix).toBe('We send a verification code to ')
expect(prefix).not.toContain('<')
expect(prefix).not.toContain('>')
expect(typeof prefix).toBe('string')
console.log('\n✅ Translation structure is secure - no HTML content')
})
it('should confirm React automatic escaping works correctly', () => {
console.log('\n⚡ React Security Mechanism Test')
console.log('=================================')
// Test React's automatic escaping with various inputs
const dangerousInputs = [
'<script>alert("xss")</script>',
'<img src="x" onerror="alert(1)">',
'"><script>alert(2)</script>',
'\'>alert(3)</script>',
'<div onclick="alert(4)">click</div>',
]
dangerousInputs.forEach((input, index) => {
const TestComponent = () => <strong>{input}</strong>
const { container } = render(<TestComponent />)
const strongElement = container.querySelector('strong')
const scriptElements = container.querySelectorAll('script')
console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`)
console.log(` - Rendered as text: ${strongElement?.textContent === input}`)
console.log(` - No script execution: ${scriptElements.length === 0}`)
expect(strongElement?.textContent).toBe(input)
expect(scriptElements).toHaveLength(0)
})
console.log('\n🛡 React automatic escaping is working perfectly')
})
})
export {}
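
Beyond the split-key fix verified above, react-i18next also ships a Trans component that keeps markup in the translation while still escaping interpolated values. A sketch, assuming the original login.checkCode.tips key were kept with its <strong> markup:

import { Trans } from 'react-i18next'

// The email value is interpolated as text, so any HTML inside it is escaped
// by React; only the <strong> element from the translation becomes markup.
const CheckCodeTips = ({ email }: { email: string }) => (
  <Trans
    i18nKey="login.checkCode.tips"
    values={{ email }}
    components={{ strong: <strong /> }}
  />
)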

View File

@ -20,12 +20,18 @@ import cn from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'
import { fetchAppDetail } from '@/service/apps'
import { fetchAppDetailDirect } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import Loading from '@/app/components/base/loading'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import type { App } from '@/types/app'
import useDocumentTitle from '@/hooks/use-document-title'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import dynamic from 'next/dynamic'
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
})
export type IAppDetailLayoutProps = {
children: React.ReactNode
@ -48,6 +54,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
setAppDetail: state.setAppDetail,
setAppSiderbarExpand: state.setAppSiderbarExpand,
})))
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false)
const [appDetailRes, setAppDetailRes] = useState<App | null>(null)
const [navigation, setNavigation] = useState<Array<{
@ -111,7 +118,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
useEffect(() => {
setAppDetail()
setIsLoadingAppDetail(true)
fetchAppDetail({ url: '/apps', id: appId }).then((res) => {
fetchAppDetailDirect({ url: '/apps', id: appId }).then((res: App) => {
setAppDetailRes(res)
}).catch((e: any) => {
if (e.status === 404)
@ -163,6 +170,9 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
<div className="grow overflow-hidden bg-components-panel-bg">
{children}
</div>
{showTagManagementModal && (
<TagManagementModal type='app' show={showTagManagementModal} />
)}
</div>
)
}

View File

@ -0,0 +1,156 @@
import React from 'react'
import { render } from '@testing-library/react'
import '@testing-library/jest-dom'
import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing'
// Mock dependencies to isolate the SVG rendering issue
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('SVG Attribute Error Reproduction', () => {
// Capture console errors
const originalError = console.error
let errorMessages: string[] = []
beforeEach(() => {
errorMessages = []
console.error = jest.fn((message) => {
errorMessages.push(message)
originalError(message)
})
})
afterEach(() => {
console.error = originalError
})
it('should reproduce inkscape attribute errors when rendering OpikIconBig', () => {
console.log('\n=== TESTING OpikIconBig SVG ATTRIBUTE ERRORS ===')
// Test multiple renders to check for inconsistency
for (let i = 0; i < 5; i++) {
console.log(`\nRender attempt ${i + 1}:`)
const { unmount } = render(<OpikIconBig />)
// Check for specific inkscape attribute errors
const inkscapeErrors = errorMessages.filter(msg =>
typeof msg === 'string' && msg.includes('inkscape'),
)
if (inkscapeErrors.length > 0) {
console.log(`Found ${inkscapeErrors.length} inkscape errors:`)
inkscapeErrors.forEach((error, index) => {
console.log(` ${index + 1}. ${error.substring(0, 100)}...`)
})
}
else {
console.log('No inkscape errors found in this render')
}
unmount()
// Clear errors for next iteration
errorMessages = []
}
})
it('should analyze the SVG structure causing the errors', () => {
console.log('\n=== ANALYZING SVG STRUCTURE ===')
// Import the JSON data directly
const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json')
console.log('Icon structure analysis:')
console.log('- Root element:', iconData.icon.name)
console.log('- Children count:', iconData.icon.children?.length || 0)
// Find problematic elements
const findProblematicElements = (node: any, path = '') => {
const problematicElements: any[] = []
if (node.name && (node.name.includes(':') || node.name.startsWith('sodipodi'))) {
problematicElements.push({
path,
name: node.name,
attributes: Object.keys(node.attributes || {}),
})
}
// Check attributes for inkscape/sodipodi properties
if (node.attributes) {
const problematicAttrs = Object.keys(node.attributes).filter(attr =>
attr.startsWith('inkscape:') || attr.startsWith('sodipodi:'),
)
if (problematicAttrs.length > 0) {
problematicElements.push({
path,
name: node.name,
problematicAttributes: problematicAttrs,
})
}
}
if (node.children) {
node.children.forEach((child: any, index: number) => {
problematicElements.push(
...findProblematicElements(child, `${path}/${node.name}[${index}]`),
)
})
}
return problematicElements
}
const problematicElements = findProblematicElements(iconData.icon, 'root')
console.log(`\n🚨 Found ${problematicElements.length} problematic elements:`)
problematicElements.forEach((element, index) => {
console.log(`\n${index + 1}. Element: ${element.name}`)
console.log(` Path: ${element.path}`)
if (element.problematicAttributes)
console.log(` Problematic attributes: ${element.problematicAttributes.join(', ')}`)
})
})
it('should test the normalizeAttrs function behavior', () => {
console.log('\n=== TESTING normalizeAttrs FUNCTION ===')
const { normalizeAttrs } = require('@/app/components/base/icons/utils')
const testAttributes = {
'inkscape:showpageshadow': '2',
'inkscape:pageopacity': '0.0',
'inkscape:pagecheckerboard': '0',
'inkscape:deskcolor': '#d1d1d1',
'sodipodi:docname': 'opik-icon-big.svg',
      'xmlns:inkscape': 'http://www.inkscape.org/namespaces/inkscape',
      'xmlns:sodipodi': 'http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd',
      'xmlns:svg': 'http://www.w3.org/2000/svg',
'data-name': 'Layer 1',
'normal-attr': 'value',
'class': 'test-class',
}
console.log('Input attributes:', Object.keys(testAttributes))
const normalized = normalizeAttrs(testAttributes)
console.log('Normalized attributes:', Object.keys(normalized))
console.log('Normalized values:', normalized)
// Check if problematic attributes are still present
const problematicKeys = Object.keys(normalized).filter(key =>
key.toLowerCase().includes('inkscape') || key.toLowerCase().includes('sodipodi'),
)
if (problematicKeys.length > 0)
console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`)
else
console.log('✅ No problematic attributes found after normalization')
})
})
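
Given the normalizeAttrs filtering added later in this diff, the fixture above should reduce to only its safe attributes. An illustrative expectation (not part of the committed test file):

// Editor metadata (inkscape:*, sodipodi:*, xmlns:inkscape/sodipodi/svg,
// data-name) is dropped; remaining keys are camelCased, and class becomes
// className.
expect(normalizeAttrs({
  'inkscape:pageopacity': '0.0',
  'sodipodi:docname': 'opik-icon-big.svg',
  'data-name': 'Layer 1',
  'normal-attr': 'value',
  'class': 'test-class',
})).toEqual({
  normalAttr: 'value',
  className: 'test-class',
})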

View File

@ -1,12 +1,9 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useEffect, useRef, useState } from 'react'
import {
RiEqualizer2Line,
} from '@remixicon/react'
import React, { useCallback, useRef, useState } from 'react'
import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import cn from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,
@ -17,13 +14,13 @@ type Props = {
readOnly: boolean
className?: string
hasConfigured: boolean
controlShowPopup?: number
children?: React.ReactNode
} & PopupProps
const ConfigBtn: FC<Props> = ({
className,
hasConfigured,
controlShowPopup,
children,
...popupProps
}) => {
const [open, doSetOpen] = useState(false)
@ -37,13 +34,6 @@ const ConfigBtn: FC<Props> = ({
setOpen(!openRef.current)
}, [setOpen])
useEffect(() => {
if (controlShowPopup)
// setOpen(!openRef.current)
setOpen(true)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [controlShowPopup])
if (popupProps.readOnly && !hasConfigured)
return null
@ -52,14 +42,11 @@ const ConfigBtn: FC<Props> = ({
open={open}
onOpenChange={setOpen}
placement='bottom-end'
offset={{
mainAxis: 12,
crossAxis: hasConfigured ? 8 : 49,
}}
offset={12}
>
<PortalToFollowElemTrigger onClick={handleTrigger}>
<div className={cn(className, 'rounded-md p-1')}>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
<div className="select-none">
{children}
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[11]'>

View File

@ -1,8 +1,9 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useEffect, useState } from 'react'
import React, { useEffect, useState } from 'react'
import {
RiArrowDownDoubleLine,
RiEqualizer2Line,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { usePathname } from 'next/navigation'
@ -180,10 +181,6 @@ const Panel: FC = () => {
})()
}, [])
const [controlShowPopup, setControlShowPopup] = useState<number>(0)
const showPopup = useCallback(() => {
setControlShowPopup(Date.now())
}, [setControlShowPopup])
if (!isLoaded) {
return (
<div className='mb-3 flex items-center justify-between'>
@ -196,46 +193,66 @@ const Panel: FC = () => {
return (
<div className={cn('flex items-center justify-between')}>
<div
className={cn(
'flex cursor-pointer items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
controlShowPopup && 'border-effects-highlight-lightmode-off bg-background-default-lighter',
)}
onClick={showPopup}
>
{!inUseTracingProvider && (
<>
{!inUseTracingProvider && (
<ConfigButton
appId={appId}
readOnly={readOnly}
hasConfigured={false}
enabled={enabled}
onStatusChange={handleTracingEnabledChange}
chosenProvider={inUseTracingProvider}
onChooseProvider={handleChooseProvider}
arizeConfig={arizeConfig}
phoenixConfig={phoenixConfig}
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
>
<div
className={cn(
'flex cursor-pointer select-none items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
)}
>
<TracingIcon size='md' />
<div className='system-sm-semibold mx-2 text-text-secondary'>{t(`${I18N_PREFIX}.title`)}</div>
<div className='flex items-center' onClick={e => e.stopPropagation()}>
<ConfigButton
appId={appId}
readOnly={readOnly}
hasConfigured={false}
enabled={enabled}
onStatusChange={handleTracingEnabledChange}
chosenProvider={inUseTracingProvider}
onChooseProvider={handleChooseProvider}
arizeConfig={arizeConfig}
phoenixConfig={phoenixConfig}
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
controlShowPopup={controlShowPopup}
/>
<div className='rounded-md p-1'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
<Divider type='vertical' className='h-3.5' />
<div className='rounded-md p-1'>
<RiArrowDownDoubleLine className='h-4 w-4 text-text-tertiary' />
</div>
</>
)}
{hasConfiguredTracing && (
<>
</div>
</ConfigButton>
)}
{hasConfiguredTracing && (
<ConfigButton
appId={appId}
readOnly={readOnly}
hasConfigured
enabled={enabled}
onStatusChange={handleTracingEnabledChange}
chosenProvider={inUseTracingProvider}
onChooseProvider={handleChooseProvider}
arizeConfig={arizeConfig}
phoenixConfig={phoenixConfig}
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
>
<div
className={cn(
'flex cursor-pointer select-none items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
)}
>
<div className='ml-4 mr-1 flex items-center'>
<Indicator color={enabled ? 'green' : 'gray'} />
<div className='system-xs-semibold-uppercase ml-1.5 text-text-tertiary'>
@ -243,33 +260,14 @@ const Panel: FC = () => {
</div>
</div>
{InUseProviderIcon && <InUseProviderIcon className='ml-1 h-4' />}
<Divider type='vertical' className='h-3.5' />
<div className='flex items-center' onClick={e => e.stopPropagation()}>
<ConfigButton
appId={appId}
readOnly={readOnly}
hasConfigured
className='ml-2'
enabled={enabled}
onStatusChange={handleTracingEnabledChange}
chosenProvider={inUseTracingProvider}
onChooseProvider={handleChooseProvider}
arizeConfig={arizeConfig}
phoenixConfig={phoenixConfig}
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
onConfigUpdated={handleTracingConfigUpdated}
onConfigRemoved={handleTracingConfigRemoved}
controlShowPopup={controlShowPopup}
/>
<div className='ml-2 rounded-md p-1'>
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
</div>
</>
)}
</div >
</div >
<Divider type='vertical' className='h-3.5' />
</div>
</ConfigButton>
)}
</div>
)
}
export default React.memo(Panel)

View File

@ -56,33 +56,50 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
}, [isMobile, setShowTips])
return <div>
{hasRelatedApps && (
<>
{!isMobile && (
<Tooltip
position='right'
noDecoration
popupContent={
<LinkedAppsPanel
relatedApps={relatedApps.data}
isMobile={isMobile}
/>
}
>
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 text-text-secondary'>
<span>{relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')}</span>
<RiInformation2Line className='h-4 w-4' />
</div>
</Tooltip>
)}
{/* Related apps for desktop */}
<div className={classNames(
'transition-all duration-200 ease-in-out',
(hasRelatedApps && !isMobile)
? 'w-auto opacity-100'
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
)}>
<Tooltip
position='right'
noDecoration
popupContent={
<LinkedAppsPanel
relatedApps={relatedApps?.data || []}
isMobile={isMobile}
/>
}
>
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 whitespace-nowrap text-text-secondary'>
<span>{relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')}</span>
<RiInformation2Line className='h-4 w-4' />
</div>
</Tooltip>
</div>
{isMobile && <div className={classNames('pb-2 pt-4 text-xs font-medium uppercase text-text-tertiary', 'flex items-center justify-center gap-1 !px-0')}>
{relatedAppsTotal || '--'}
<PaperClipIcon className='h-4 w-4 text-text-secondary' />
</div>}
</>
)}
{!hasRelatedApps && !expand && (
{/* Related apps for mobile */}
<div className={classNames(
'transition-all duration-200 ease-in-out',
(hasRelatedApps && isMobile)
? 'w-auto opacity-100'
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
)}>
<div className={classNames('pb-2 pt-4 text-xs font-medium uppercase text-text-tertiary', 'flex items-center justify-center gap-1 whitespace-nowrap !px-0')}>
{relatedAppsTotal || '--'}
<PaperClipIcon className='h-4 w-4 text-text-secondary' />
</div>
</div>
{/* No related apps tooltip */}
<div className={classNames(
'transition-all duration-200 ease-in-out',
(!hasRelatedApps && !expand)
? 'w-auto opacity-100'
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
)}>
<Tooltip
position='right'
noDecoration
@ -103,12 +120,12 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
</div>
}
>
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 text-text-secondary'>
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 whitespace-nowrap text-text-secondary'>
<span>{t('common.datasetMenus.noRelatedApp')}</span>
<RiInformation2Line className='h-4 w-4' />
</div>
</Tooltip>
)}
</div>
</div>
}

View File

@ -1,9 +1,7 @@
import React from 'react'
import DatasetUpdateForm from '@/app/components/datasets/create'
type Props = {}
const DatasetCreation = async (props: Props) => {
const DatasetCreation = async () => {
return (
<DatasetUpdateForm />
)

View File

@ -70,7 +70,10 @@ export default function CheckCode() {
<div className='pb-4 pt-2'>
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
<span>
{t('login.checkCode.tipsPrefix')}
<strong>{email}</strong>
</span>
<br />
{t('login.checkCode.validTime')}
</p>

View File

@ -93,7 +93,10 @@ export default function CheckCode() {
<div className='pb-4 pt-2'>
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
<span>
{t('login.checkCode.tipsPrefix')}
<strong>{email}</strong>
</span>
<br />
{t('login.checkCode.validTime')}
</p>

View File

@ -29,15 +29,17 @@ const DatasetInfo: FC<Props> = ({
<div className='mr-3 shrink-0'>
<AppIcon innerIcon={DatasetSvg} className='!border-[0.5px] !border-indigo-100 !bg-indigo-25' />
</div>
{expand && (
<div className='mt-2'>
<div className='system-md-semibold text-text-secondary'>
{name}
</div>
<div className='system-2xs-medium-uppercase mt-1 text-text-tertiary'>{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}</div>
<div className='system-xs-regular my-3 text-text-tertiary first-letter:capitalize'>{description}</div>
<div className={`transition-all duration-200 ease-in-out ${
expand
? 'mt-2 w-auto opacity-100'
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0'
}`}>
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>
{name}
</div>
)}
<div className='system-2xs-medium-uppercase mt-1 whitespace-nowrap text-text-tertiary'>{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}</div>
<div className='system-xs-regular my-3 whitespace-nowrap text-text-tertiary first-letter:capitalize'>{description}</div>
</div>
{extraInfo}
</div>
)

View File

@ -88,7 +88,8 @@ const HeaderOptions: FC<Props> = ({
await clearAllAnnotations(appId)
onAdded()
}
catch (_) {
catch (e) {
console.error(`failed to clear all annotations, ${e}`)
}
finally {
setShowClearConfirm(false)

View File

@ -11,7 +11,7 @@ import SelectTypeItem from '../select-type-item'
import Field from './field'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscreInVarNameInput } from '@/utils/var'
import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
import ConfigContext from '@/context/debug-configuration'
import type { InputVar, MoreInfo, UploadFileSetting } from '@/app/components/workflow/types'
import Modal from '@/app/components/base/modal'
@ -111,7 +111,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
}, [checkVariableName, tempPayload.label])
const handleVarNameChange = useCallback((e: ChangeEvent<any>) => {
replaceSpaceWithUnderscreInVarNameInput(e.target)
replaceSpaceWithUnderscoreInVarNameInput(e.target)
const value = e.target.value
const { isValid, errorKey, errorMessageKey } = checkKeys([value], true)
if (!isValid) {

View File

@ -1,6 +1,6 @@
import React from 'react'
import React, { useState } from 'react'
import Link from 'next/link'
import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
type CustomLinkProps = {
@ -26,9 +26,24 @@ const CustomLink = React.memo(({
const Footer = () => {
const { t } = useTranslation()
const [isVisible, setIsVisible] = useState(true)
const handleClose = () => {
setIsVisible(false)
}
if (!isVisible)
return null
return (
<footer className='shrink-0 grow-0 px-12 py-6'>
<footer className='relative shrink-0 grow-0 px-12 py-2'>
<button
onClick={handleClose}
className='absolute right-2 top-2 flex h-6 w-6 cursor-pointer items-center justify-center rounded-full transition-colors duration-200 ease-in-out hover:bg-components-main-nav-nav-button-bg-active'
aria-label="Close footer"
>
<RiCloseLine className='h-4 w-4 text-text-tertiary hover:text-text-secondary' />
</button>
<h3 className='text-gradient text-xl font-semibold leading-tight'>{t('app.join')}</h3>
<p className='system-sm-regular mt-1 text-text-tertiary'>{t('app.communityIntro')}</p>
<div className='mt-3 flex items-center gap-2'>

View File

@ -115,8 +115,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}, [])
useEffect(() => {
if (appData?.site.default_language)
changeLanguage(appData.site.default_language)
const setLocaleFromProps = async () => {
if (appData?.site.default_language)
await changeLanguage(appData.site.default_language)
}
setLocaleFromProps()
}, [appData])
const [sidebarCollapseState, setSidebarCollapseState] = useState<boolean>(false)
@ -159,9 +162,21 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
return currentConversationId
}, [currentConversationId, newConversationId])
const { data: appPinnedConversationData, mutate: mutateAppPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100))
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100))
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId))
const { data: appPinnedConversationData, mutate: mutateAppPinnedConversationData } = useSWR(
appId ? ['appConversationData', isInstalledApp, appId, true] : null,
() => fetchConversations(isInstalledApp, appId, undefined, true, 100),
{ revalidateOnFocus: false, revalidateOnReconnect: false },
)
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(
appId ? ['appConversationData', isInstalledApp, appId, false] : null,
() => fetchConversations(isInstalledApp, appId, undefined, false, 100),
{ revalidateOnFocus: false, revalidateOnReconnect: false },
)
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(
chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null,
() => fetchChatList(chatShouldReloadKey, isInstalledApp, appId),
{ revalidateOnFocus: false, revalidateOnReconnect: false },
)
const [clearChatList, setClearChatList] = useState(false)
const [isResponding, setIsResponding] = useState(false)

View File

@ -101,15 +101,15 @@ export const useEmbeddedChatbot = () => {
if (localeParam) {
// If locale parameter exists in URL, use it instead of default
changeLanguage(localeParam)
await changeLanguage(localeParam)
}
else if (localeFromSysVar) {
// If locale is set as a system variable, use that
changeLanguage(localeFromSysVar)
await changeLanguage(localeFromSysVar)
}
else if (appInfo?.site.default_language) {
// Otherwise use the default from app config
changeLanguage(appInfo.site.default_language)
await changeLanguage(appInfo.site.default_language)
}
}

View File

@ -3,6 +3,7 @@ import {
memo,
useMemo,
} from 'react'
import { RiExternalLinkLine } from '@remixicon/react'
import type { AnyFieldApi } from '@tanstack/react-form'
import { useStore } from '@tanstack/react-form'
import cn from '@/utils/classnames'
@ -200,6 +201,22 @@ const BaseField = ({
</div>
)
}
{
formSchema.url && (
<a
className='system-xs-regular mt-4 flex items-center text-text-accent'
href={formSchema?.url}
target='_blank'
>
<span className='break-all'>
{renderI18nObject(formSchema?.help as any)}
</span>
{
<RiExternalLinkLine className='ml-1 h-3 w-3' />
}
</a>
)
}
</div>
</div>
)

View File

@ -14,9 +14,26 @@ export type Attrs = {
export function normalizeAttrs(attrs: Attrs = {}): Attrs {
return Object.keys(attrs).reduce((acc: Attrs, key) => {
// Filter out editor metadata attributes before processing
if (key.startsWith('inkscape:')
|| key.startsWith('sodipodi:')
|| key.startsWith('xmlns:inkscape')
|| key.startsWith('xmlns:sodipodi')
|| key.startsWith('xmlns:svg')
|| key === 'data-name')
return acc
const val = attrs[key]
key = key.replace(/([-]\w)/g, (g: string) => g[1].toUpperCase())
key = key.replace(/([:]\w)/g, (g: string) => g[1].toUpperCase())
// Additional filter after camelCase conversion
if (key === 'xmlnsInkscape'
|| key === 'xmlnsSodipodi'
|| key === 'xmlnsSvg'
|| key === 'dataName')
return acc
switch (key) {
case 'class':
acc.className = val

Some files were not shown because too many files have changed in this diff.