mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/model-auth
This commit is contained in:
commit
4e6cb26778
|
|
@ -0,0 +1,27 @@
|
|||
name: "✨ Refactor"
|
||||
description: Refactor existing code for improved readability and maintainability.
|
||||
title: "[Chore/Refactor] "
|
||||
labels:
|
||||
- refactor
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Description
|
||||
placeholder: "Describe the refactor you are proposing."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation
|
||||
placeholder: "Explain why this refactor is necessary."
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
placeholder: "Add any other context or screenshots about the request here."
|
||||
validations:
|
||||
required: false
|
||||
|
|
@ -99,3 +99,6 @@ jobs:
|
|||
|
||||
- name: Run Tool
|
||||
run: uv run --project api bash dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ on:
|
|||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import secrets
|
|||
from typing import Any, Optional
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
|
@ -457,7 +458,7 @@ def convert_to_agent_apps():
|
|||
"""
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query))
|
||||
rs = conn.execute(sa.text(sql_query))
|
||||
|
||||
apps = []
|
||||
for i in rs:
|
||||
|
|
@ -702,7 +703,7 @@ def fix_app_site_missing():
|
|||
sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id
|
||||
where sites.id is null limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
processed_count = 0
|
||||
for i in rs:
|
||||
|
|
@ -916,7 +917,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
)
|
||||
orphaned_message_files = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
|
||||
|
||||
|
|
@ -937,7 +938,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
click.echo(
|
||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||
)
|
||||
|
|
@ -954,7 +955,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
|
|
@ -974,7 +975,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
|
||||
elif ids_table["type"] == "text":
|
||||
|
|
@ -989,7 +990,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
|
|
@ -1008,7 +1009,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
f"FROM {ids_table['table']}"
|
||||
)
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
for j in i[0]:
|
||||
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
|
||||
|
|
@ -1037,7 +1038,7 @@ def clear_orphaned_file_records(force: bool):
|
|||
click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
|
||||
query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
|
||||
conn.execute(sa.text(query), {"ids": tuple(orphaned_files)})
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
|
||||
return
|
||||
|
|
@ -1107,7 +1108,7 @@ def remove_orphaned_files_on_storage(force: bool):
|
|||
click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
|
||||
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query))
|
||||
rs = conn.execute(sa.text(query))
|
||||
for i in rs:
|
||||
all_files_in_tables.append(str(i[0]))
|
||||
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3
|
|||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ from .datasets import (
|
|||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
upload_file,
|
||||
website,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "message_count": i.message_count})
|
||||
|
||||
|
|
@ -176,7 +176,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"}
|
||||
|
|
@ -310,7 +310,7 @@ ORDER BY
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
|
|
@ -373,7 +373,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
|
|
@ -435,7 +435,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)})
|
||||
|
||||
|
|
@ -495,7 +495,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)})
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
|||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
|
@ -71,7 +72,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
|
||||
|
|
@ -133,7 +134,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
|
|
@ -195,7 +196,7 @@ WHERE
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
|
|
@ -277,7 +278,7 @@ GROUP BY
|
|||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
|
|
|
|||
|
|
@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
return marshal(document_dict, document_status_fields)
|
||||
|
||||
|
||||
class DocumentDetailApi(DocumentResource):
|
||||
class DocumentApi(DocumentResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@setup_required
|
||||
|
|
@ -730,6 +730,28 @@ class DocumentDetailApi(DocumentResource):
|
|||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@setup_required
|
||||
|
|
@ -768,30 +790,6 @@ class DocumentProcessingApi(DocumentResource):
|
|||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -1037,11 +1035,10 @@ api.add_resource(
|
|||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(
|
||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
|
|
@ -127,7 +127,7 @@ class EducationActivateLimitError(BaseHTTPException):
|
|||
code = 429
|
||||
|
||||
|
||||
class CompilanceRateLimitError(BaseHTTPException):
|
||||
error_code = "compilance_rate_limit"
|
||||
class ComplianceRateLimitError(BaseHTTPException):
|
||||
error_code = "compliance_rate_limit"
|
||||
description = "Rate limit exceeded for downloading compliance report."
|
||||
code = 429
|
||||
|
|
|
|||
|
|
@ -58,21 +58,38 @@ class InstalledAppsListApi(Resource):
|
|||
# filter out apps that user doesn't have access to
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
user_id = current_user.id
|
||||
res = []
|
||||
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
|
||||
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
|
||||
|
||||
# Pre-filter out apps without setting or with sso_verified
|
||||
filtered_installed_apps = []
|
||||
app_id_to_app_code = {}
|
||||
|
||||
for installed_app in installed_app_list:
|
||||
webapp_setting = webapp_settings.get(installed_app["app"].id)
|
||||
if not webapp_setting:
|
||||
app_id = installed_app["app"].id
|
||||
webapp_setting = webapp_settings.get(app_id)
|
||||
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
if webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
|
||||
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_code=app_code,
|
||||
):
|
||||
app_code = AppService.get_app_code_by_id(str(app_id))
|
||||
app_id_to_app_code[app_id] = app_code
|
||||
filtered_installed_apps.append(installed_app)
|
||||
|
||||
app_codes = list(app_id_to_app_code.values())
|
||||
|
||||
# Batch permission check
|
||||
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
||||
user_id=user_id,
|
||||
app_codes=app_codes,
|
||||
)
|
||||
|
||||
# Keep only allowed apps
|
||||
res = []
|
||||
for installed_app in filtered_installed_apps:
|
||||
app_id = installed_app["app"].id
|
||||
app_code = app_id_to_app_code[app_id]
|
||||
if permissions.get(app_code):
|
||||
res.append(installed_app)
|
||||
|
||||
installed_app_list = res
|
||||
logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.service_api import api
|
||||
|
|
@ -30,6 +30,7 @@ from libs import helper
|
|||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
|
|
@ -113,7 +114,7 @@ class ChatApi(Resource):
|
|||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
|
||||
parser.add_argument("workflow_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
|
|
@ -128,6 +129,12 @@ class ChatApi(Resource):
|
|||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except IsDraftWorkflowError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except WorkflowIdFormatError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import request
|
|||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
|
|
@ -34,6 +34,7 @@ from libs.helper import TimestampField
|
|||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
|
@ -120,6 +121,59 @@ class WorkflowRunApi(Resource):
|
|||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
|
||||
"""
|
||||
Run specific workflow by ID
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except IsDraftWorkflowError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except WorkflowIdFormatError as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
|
|
@ -193,5 +247,6 @@ class WorkflowAppLogApi(Resource):
|
|||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>")
|
||||
api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
|
||||
|
|
|
|||
|
|
@ -358,39 +358,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
try:
|
||||
# delete document
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
|
|
@ -473,7 +440,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
|||
return data
|
||||
|
||||
|
||||
class DocumentDetailApi(DatasetApiResource):
|
||||
class DocumentApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
|
|
@ -567,6 +534,37 @@ class DocumentDetailApi(DatasetApiResource):
|
|||
|
||||
return response
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
try:
|
||||
# delete document
|
||||
DocumentService.delete_document(document)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
|
|
@ -588,7 +586,6 @@ api.add_resource(
|
|||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig):
|
|||
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
default: Optional[Union[int, str, float, bool]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute):
|
|||
case FileAttribute.TRANSFER_METHOD:
|
||||
return file.transfer_method.value
|
||||
case FileAttribute.URL:
|
||||
return file.remote_url
|
||||
return _to_url(file)
|
||||
case FileAttribute.EXTENSION:
|
||||
return file.extension
|
||||
case FileAttribute.RELATED_ID:
|
||||
|
|
|
|||
|
|
@ -208,6 +208,7 @@ class BasePluginClient:
|
|||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from collections.abc import Mapping
|
|||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from extensions.ext_logging import get_request_id
|
||||
|
||||
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
|
|
@ -11,7 +13,7 @@ class PluginDaemonError(Exception):
|
|||
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"{self.__class__.__name__}: {self.description}"
|
||||
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"
|
||||
|
||||
|
||||
class PluginDaemonInternalError(PluginDaemonError):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from urllib.parse import urlparse
|
|||
import requests
|
||||
from elasticsearch import Elasticsearch
|
||||
from flask import current_app
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
|
|
@ -149,7 +150,7 @@ class ElasticSearchVector(BaseVector):
|
|||
return cast(str, info["version"]["number"])
|
||||
|
||||
def _check_version(self):
|
||||
if self._version < "8.0.0":
|
||||
if parse_version(self._version) < parse_version("8.0.0"):
|
||||
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
|
||||
|
||||
def get_type(self) -> str:
|
||||
|
|
|
|||
|
|
@ -20,9 +20,6 @@ class Tool(ABC):
|
|||
The base class of a tool
|
||||
"""
|
||||
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ class BuiltinTool(Tool):
|
|||
:param meta: the meta data of a tool call processing
|
||||
"""
|
||||
|
||||
provider: str
|
||||
|
||||
def __init__(self, provider: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
|
|
|||
|
|
@ -21,9 +21,6 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
|||
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiToolBundle
|
||||
provider_id: str
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,23 +8,16 @@ from core.mcp.mcp_client import MCPClient
|
|||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
server_url: str
|
||||
provider_id: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.runtime_parameters = None
|
||||
self.server_url = server_url
|
||||
self.provider_id = provider_id
|
||||
|
||||
|
|
|
|||
|
|
@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
|
|
@ -21,7 +16,7 @@ class PluginTool(Tool):
|
|||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
self.runtime_parameters: Optional[list[ToolParameter]] = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from os import listdir, path
|
|||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from yarl import URL
|
||||
|
||||
|
|
@ -616,7 +617,7 @@ class ToolManager:
|
|||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
|||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
|
|
|||
|
|
@ -25,15 +25,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
workflow_app_id: str
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
thread_pool_id: Optional[str] = None
|
||||
workflow_as_tool_id: str
|
||||
|
||||
label: str
|
||||
|
||||
"""
|
||||
Workflow tool.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class SegmentType(StrEnum):
|
|||
elif array_validation == ArrayValidation.FIRST:
|
||||
return element_type.is_valid(value[0])
|
||||
else:
|
||||
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
|
||||
return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
|
||||
|
||||
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
|
||||
"""
|
||||
|
|
@ -152,7 +152,7 @@ class SegmentType(StrEnum):
|
|||
|
||||
|
||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||
# ARRAY_ANY does not have correpond element type.
|
||||
# ARRAY_ANY does not have corresponding element type.
|
||||
SegmentType.ARRAY_STRING: SegmentType.STRING,
|
||||
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
|
||||
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
|
||||
|
|
|
|||
|
|
@ -318,6 +318,33 @@ class ToolNode(BaseNode):
|
|||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": message.message.text,
|
||||
}
|
||||
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
|
|
|
|||
|
|
@ -136,6 +136,8 @@ def init_app(app: DifyApp):
|
|||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.instrumentation.requests import RequestsInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
|
|
@ -234,6 +236,8 @@ def init_app(app: DifyApp):
|
|||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
|
||||
instrument_exception_logging()
|
||||
init_sqlalchemy_instrumentor(app)
|
||||
RedisInstrumentor().instrument()
|
||||
RequestsInstrumentor().instrument()
|
||||
atexit.register(shutdown_tracer)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ model_config_fields = {
|
|||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_detail_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
|
|
@ -77,6 +79,7 @@ app_detail_fields = {
|
|||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
|
|
@ -92,8 +95,6 @@ model_config_partial_fields = {
|
|||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_partial_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
|
|
@ -185,7 +186,6 @@ app_detail_fields_with_site = {
|
|||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"site": fields.Nested(site_fields),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
|
|
@ -195,6 +195,8 @@ app_detail_fields_with_site = {
|
|||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"site": fields.Nested(site_fields),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import json
|
|||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
|
||||
|
|
@ -83,9 +84,9 @@ class AccountStatus(enum.StrEnum):
|
|||
|
||||
class Account(UserMixin, Base):
|
||||
__tablename__ = "accounts"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
email: Mapped[str] = mapped_column(String(255))
|
||||
password: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
|
|
@ -97,7 +98,7 @@ class Account(UserMixin, Base):
|
|||
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=db.text("'active'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
|
||||
initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
|
|
@ -195,14 +196,14 @@ class TenantStatus(enum.StrEnum):
|
|||
|
||||
class Tenant(Base):
|
||||
__tablename__ = "tenants"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
encrypt_public_key = db.Column(db.Text)
|
||||
plan: Mapped[str] = mapped_column(String(255), server_default=db.text("'basic'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying"))
|
||||
custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
encrypt_public_key = db.Column(sa.Text)
|
||||
plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
||||
custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
||||
|
|
@ -225,16 +226,16 @@ class Tenant(Base):
|
|||
class TenantAccountJoin(Base):
|
||||
__tablename__ = "tenant_account_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
db.Index("tenant_account_join_account_id_idx", "account_id"),
|
||||
db.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
|
||||
db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
|
||||
sa.Index("tenant_account_join_account_id_idx", "account_id"),
|
||||
sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
role: Mapped[str] = mapped_column(String(16), server_default="normal")
|
||||
invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
|
|
@ -244,12 +245,12 @@ class TenantAccountJoin(Base):
|
|||
class AccountIntegrate(Base):
|
||||
__tablename__ = "account_integrates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
db.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
|
||||
db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
|
||||
sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
|
||||
sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
account_id: Mapped[str] = mapped_column(StringUUID)
|
||||
provider: Mapped[str] = mapped_column(String(16))
|
||||
open_id: Mapped[str] = mapped_column(String(255))
|
||||
|
|
@ -261,20 +262,20 @@ class AccountIntegrate(Base):
|
|||
class InvitationCode(Base):
|
||||
__tablename__ = "invitation_codes"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
db.Index("invitation_codes_batch_idx", "batch"),
|
||||
db.Index("invitation_codes_code_idx", "code", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
|
||||
sa.Index("invitation_codes_batch_idx", "batch"),
|
||||
sa.Index("invitation_codes_code_idx", "code", "status"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(db.Integer)
|
||||
id: Mapped[int] = mapped_column(sa.Integer)
|
||||
batch: Mapped[str] = mapped_column(String(255))
|
||||
code: Mapped[str] = mapped_column(String(32))
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=db.text("'unused'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying"))
|
||||
used_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class TenantPluginPermission(Base):
|
||||
|
|
@ -290,11 +291,11 @@ class TenantPluginPermission(Base):
|
|||
|
||||
__tablename__ = "account_plugin_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
|
||||
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||
sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone")
|
||||
debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone")
|
||||
|
|
@ -313,16 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base):
|
|||
|
||||
__tablename__ = "tenant_plugin_auto_upgrade_strategies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
|
||||
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day
|
||||
upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day
|
||||
upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(db.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name)
|
||||
created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import enum
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
|
|
@ -19,11 +19,11 @@ class APIBasedExtensionPoint(enum.Enum):
|
|||
class APIBasedExtension(Base):
|
||||
__tablename__ = "api_based_extensions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
db.Index("api_based_extension_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
|
||||
sa.Index("api_based_extension_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from datetime import datetime
|
|||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
|
@ -38,23 +39,23 @@ class DatasetPermissionEnum(enum.StrEnum):
|
|||
class Dataset(Base):
|
||||
__tablename__ = "datasets"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
db.Index("dataset_tenant_idx", "tenant_id"),
|
||||
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
sa.Index("dataset_tenant_idx", "tenant_id"),
|
||||
sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
PROVIDER_LIST = ["vendor", "external", None]
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
description = mapped_column(db.Text, nullable=True)
|
||||
provider: Mapped[str] = mapped_column(String(255), server_default=db.text("'vendor'::character varying"))
|
||||
permission: Mapped[str] = mapped_column(String(255), server_default=db.text("'only_me'::character varying"))
|
||||
description = mapped_column(sa.Text, nullable=True)
|
||||
provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying"))
|
||||
permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying"))
|
||||
data_source_type = mapped_column(String(255))
|
||||
indexing_technique: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
index_struct = mapped_column(db.Text, nullable=True)
|
||||
index_struct = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -63,7 +64,7 @@ class Dataset(Base):
|
|||
embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=True)
|
||||
retrieval_model = mapped_column(JSONB, nullable=True)
|
||||
built_in_field_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
|
|
@ -262,14 +263,14 @@ class Dataset(Base):
|
|||
class DatasetProcessRule(Base):
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
|
||||
rules = mapped_column(db.Text, nullable=True)
|
||||
mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
|
||||
rules = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
|
@ -302,20 +303,20 @@ class DatasetProcessRule(Base):
|
|||
class Document(Base):
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
db.Index("document_dataset_id_idx", "dataset_id"),
|
||||
db.Index("document_is_paused_idx", "is_paused"),
|
||||
db.Index("document_tenant_idx", "tenant_id"),
|
||||
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
sa.Index("document_dataset_id_idx", "dataset_id"),
|
||||
sa.Index("document_is_paused_idx", "is_paused"),
|
||||
sa.Index("document_tenant_idx", "tenant_id"),
|
||||
sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
data_source_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_source_info = mapped_column(db.Text, nullable=True)
|
||||
data_source_info = mapped_column(sa.Text, nullable=True)
|
||||
dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
|
||||
batch: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
|
@ -328,8 +329,8 @@ class Document(Base):
|
|||
processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# parsing
|
||||
file_id = mapped_column(db.Text, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True) # TODO: make this not nullable
|
||||
file_id = mapped_column(sa.Text, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable
|
||||
parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# cleaning
|
||||
|
|
@ -339,32 +340,32 @@ class Document(Base):
|
|||
splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# indexing
|
||||
tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
indexing_latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
|
||||
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# pause
|
||||
is_paused: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
paused_by = mapped_column(StringUUID, nullable=True)
|
||||
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# error
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# basic fields
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
archived: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
archived_reason = mapped_column(String(255), nullable=True)
|
||||
archived_by = mapped_column(StringUUID, nullable=True)
|
||||
archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
doc_type = mapped_column(String(40), nullable=True)
|
||||
doc_metadata = mapped_column(JSONB, nullable=True)
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying"))
|
||||
doc_language = mapped_column(String(255), nullable=True)
|
||||
|
||||
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
|
||||
|
|
@ -643,44 +644,44 @@ class Document(Base):
|
|||
class DocumentSegment(Base):
|
||||
__tablename__ = "document_segments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
db.Index("document_segment_dataset_id_idx", "dataset_id"),
|
||||
db.Index("document_segment_document_id_idx", "document_id"),
|
||||
db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
|
||||
db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
|
||||
db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
|
||||
db.Index("document_segment_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
sa.Index("document_segment_dataset_id_idx", "dataset_id"),
|
||||
sa.Index("document_segment_document_id_idx", "document_id"),
|
||||
sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
|
||||
sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
|
||||
sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"),
|
||||
sa.Index("document_segment_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int]
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
answer = mapped_column(db.Text, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
answer = mapped_column(sa.Text, nullable=True)
|
||||
word_count: Mapped[int]
|
||||
tokens: Mapped[int]
|
||||
|
||||
# indexing fields
|
||||
keywords = mapped_column(db.JSON, nullable=True)
|
||||
keywords = mapped_column(sa.JSON, nullable=True)
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
|
||||
# basic fields
|
||||
hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
disabled_by = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'waiting'::character varying"))
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
@property
|
||||
|
|
@ -794,36 +795,36 @@ class DocumentSegment(Base):
|
|||
class ChildChunk(Base):
|
||||
__tablename__ = "child_chunks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
|
||||
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
|
||||
db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
|
||||
db.Index("child_chunks_segment_idx", "segment_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
|
||||
sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
|
||||
sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"),
|
||||
sa.Index("child_chunks_segment_idx", "segment_id"),
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
segment_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
word_count: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
# indexing fields
|
||||
index_node_id = mapped_column(String(255), nullable=True)
|
||||
index_node_hash = mapped_column(String(255), nullable=True)
|
||||
type = mapped_column(String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
|
||||
type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
error = mapped_column(db.Text, nullable=True)
|
||||
error = mapped_column(sa.Text, nullable=True)
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
|
|
@ -841,11 +842,11 @@ class ChildChunk(Base):
|
|||
class AppDatasetJoin(Base):
|
||||
__tablename__ = "app_dataset_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
|
@ -858,13 +859,13 @@ class AppDatasetJoin(Base):
|
|||
class DatasetQuery(Base):
|
||||
__tablename__ = "dataset_queries"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
db.Index("dataset_query_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
sa.Index("dataset_query_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_app_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
|
|
@ -875,15 +876,15 @@ class DatasetQuery(Base):
|
|||
class DatasetKeywordTable(Base):
|
||||
__tablename__ = "dataset_keyword_tables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
|
||||
keyword_table = mapped_column(db.Text, nullable=False)
|
||||
keyword_table = mapped_column(sa.Text, nullable=False)
|
||||
data_source_type = mapped_column(
|
||||
String(255), nullable=False, server_default=db.text("'database'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'database'::character varying")
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -920,19 +921,19 @@ class DatasetKeywordTable(Base):
|
|||
class Embedding(Base):
|
||||
__tablename__ = "embeddings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
|
||||
db.Index("created_at_idx", "created_at"),
|
||||
sa.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
|
||||
sa.Index("created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
model_name = mapped_column(
|
||||
String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying")
|
||||
)
|
||||
hash = mapped_column(String(64), nullable=False)
|
||||
embedding = mapped_column(db.LargeBinary, nullable=False)
|
||||
embedding = mapped_column(sa.LargeBinary, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name = mapped_column(String(255), nullable=False, server_default=db.text("''::character varying"))
|
||||
provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying"))
|
||||
|
||||
def set_embedding(self, embedding_data: list[float]):
|
||||
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
|
@ -944,14 +945,14 @@ class Embedding(Base):
|
|||
class DatasetCollectionBinding(Base):
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
db.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type = mapped_column(String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
|
||||
type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False)
|
||||
collection_name = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
|
@ -959,17 +960,17 @@ class DatasetCollectionBinding(Base):
|
|||
class TidbAuthBinding(Base):
|
||||
__tablename__ = "tidb_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
|
||||
db.Index("tidb_auth_bindings_active_idx", "active"),
|
||||
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
|
||||
db.Index("tidb_auth_bindings_status_idx", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
|
||||
sa.Index("tidb_auth_bindings_active_idx", "active"),
|
||||
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
|
||||
sa.Index("tidb_auth_bindings_status_idx", "status"),
|
||||
)
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
|
||||
account: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
|
@ -979,10 +980,10 @@ class TidbAuthBinding(Base):
|
|||
class Whitelist(Base):
|
||||
__tablename__ = "whitelists"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
db.Index("whitelists_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
sa.Index("whitelists_tenant_idx", "tenant_id"),
|
||||
)
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -991,33 +992,33 @@ class Whitelist(Base):
|
|||
class DatasetPermission(Base):
|
||||
__tablename__ = "dataset_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
|
||||
db.Index("idx_dataset_permissions_account_id", "account_id"),
|
||||
db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
|
||||
sa.Index("idx_dataset_permissions_account_id", "account_id"),
|
||||
sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
has_permission: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ExternalKnowledgeApis(Base):
|
||||
__tablename__ = "external_knowledge_apis"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
|
||||
db.Index("external_knowledge_apis_name_idx", "name"),
|
||||
sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
|
||||
sa.Index("external_knowledge_apis_name_idx", "name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
settings = mapped_column(db.Text, nullable=True)
|
||||
settings = mapped_column(sa.Text, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1061,18 +1062,18 @@ class ExternalKnowledgeApis(Base):
|
|||
class ExternalKnowledgeBindings(Base):
|
||||
__tablename__ = "external_knowledge_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||
sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||
sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||
sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
external_knowledge_api_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
external_knowledge_id = mapped_column(db.Text, nullable=False)
|
||||
external_knowledge_id = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1082,57 +1083,57 @@ class ExternalKnowledgeBindings(Base):
|
|||
class DatasetAutoDisableLog(Base):
|
||||
__tablename__ = "dataset_auto_disable_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
|
||||
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
|
||||
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
notified: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class RateLimitLog(Base):
|
||||
__tablename__ = "rate_limit_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
|
||||
db.Index("rate_limit_log_tenant_idx", "tenant_id"),
|
||||
db.Index("rate_limit_log_operation_idx", "operation"),
|
||||
sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
|
||||
sa.Index("rate_limit_log_tenant_idx", "tenant_id"),
|
||||
sa.Index("rate_limit_log_operation_idx", "operation"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
operation: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
|
||||
class DatasetMetadata(Base):
|
||||
__tablename__ = "dataset_metadatas"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
sa.Index("dataset_metadata_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -1141,14 +1142,14 @@ class DatasetMetadata(Base):
|
|||
class DatasetMetadataBinding(Base):
|
||||
__tablename__ = "dataset_metadata_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
|
||||
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
|
||||
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
|
||||
db.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
|
||||
sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
|
||||
sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
|
||||
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
metadata_id = mapped_column(StringUUID, nullable=False)
|
||||
|
|
|
|||
|
|
@ -35,10 +35,10 @@ from .types import StringUUID
|
|||
|
||||
class DifySetup(Base):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
||||
version: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMode(StrEnum):
|
||||
|
|
@ -69,33 +69,33 @@ class IconType(Enum):
|
|||
|
||||
class App(Base):
|
||||
__tablename__ = "apps"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id"))
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
name: Mapped[str] = mapped_column(String(255))
|
||||
description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying"))
|
||||
description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying"))
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji
|
||||
icon = db.Column(String(255))
|
||||
icon_background: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
app_model_config_id = mapped_column(StringUUID, nullable=True)
|
||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=db.text("'normal'::character varying"))
|
||||
enable_site: Mapped[bool] = mapped_column(db.Boolean)
|
||||
enable_api: Mapped[bool] = mapped_column(db.Boolean)
|
||||
api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
|
||||
is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
|
||||
tracing = mapped_column(db.Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
|
||||
enable_site: Mapped[bool] = mapped_column(sa.Boolean)
|
||||
enable_api: Mapped[bool] = mapped_column(sa.Boolean)
|
||||
api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
|
||||
api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"))
|
||||
is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
|
||||
tracing = mapped_column(sa.Text, nullable=True)
|
||||
max_active_requests: Mapped[Optional[int]]
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def desc_or_prompt(self):
|
||||
|
|
@ -302,36 +302,36 @@ class App(Base):
|
|||
|
||||
class AppModelConfig(Base):
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
configs = mapped_column(db.JSON, nullable=True)
|
||||
configs = mapped_column(sa.JSON, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
opening_statement = mapped_column(db.Text)
|
||||
suggested_questions = mapped_column(db.Text)
|
||||
suggested_questions_after_answer = mapped_column(db.Text)
|
||||
speech_to_text = mapped_column(db.Text)
|
||||
text_to_speech = mapped_column(db.Text)
|
||||
more_like_this = mapped_column(db.Text)
|
||||
model = mapped_column(db.Text)
|
||||
user_input_form = mapped_column(db.Text)
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
opening_statement = mapped_column(sa.Text)
|
||||
suggested_questions = mapped_column(sa.Text)
|
||||
suggested_questions_after_answer = mapped_column(sa.Text)
|
||||
speech_to_text = mapped_column(sa.Text)
|
||||
text_to_speech = mapped_column(sa.Text)
|
||||
more_like_this = mapped_column(sa.Text)
|
||||
model = mapped_column(sa.Text)
|
||||
user_input_form = mapped_column(sa.Text)
|
||||
dataset_query_variable = mapped_column(String(255))
|
||||
pre_prompt = mapped_column(db.Text)
|
||||
agent_mode = mapped_column(db.Text)
|
||||
sensitive_word_avoidance = mapped_column(db.Text)
|
||||
retriever_resource = mapped_column(db.Text)
|
||||
prompt_type = mapped_column(String(255), nullable=False, server_default=db.text("'simple'::character varying"))
|
||||
chat_prompt_config = mapped_column(db.Text)
|
||||
completion_prompt_config = mapped_column(db.Text)
|
||||
dataset_configs = mapped_column(db.Text)
|
||||
external_data_tools = mapped_column(db.Text)
|
||||
file_upload = mapped_column(db.Text)
|
||||
pre_prompt = mapped_column(sa.Text)
|
||||
agent_mode = mapped_column(sa.Text)
|
||||
sensitive_word_avoidance = mapped_column(sa.Text)
|
||||
retriever_resource = mapped_column(sa.Text)
|
||||
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying"))
|
||||
chat_prompt_config = mapped_column(sa.Text)
|
||||
completion_prompt_config = mapped_column(sa.Text)
|
||||
dataset_configs = mapped_column(sa.Text)
|
||||
external_data_tools = mapped_column(sa.Text)
|
||||
file_upload = mapped_column(sa.Text)
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -553,24 +553,24 @@ class AppModelConfig(Base):
|
|||
class RecommendedApp(Base):
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
db.Index("recommended_app_app_id_idx", "app_id"),
|
||||
db.Index("recommended_app_is_listed_idx", "is_listed", "language"),
|
||||
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
sa.Index("recommended_app_app_id_idx", "app_id"),
|
||||
sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
description = mapped_column(db.JSON, nullable=False)
|
||||
description = mapped_column(sa.JSON, nullable=False)
|
||||
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
|
||||
is_listed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=True)
|
||||
install_count: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -581,20 +581,20 @@ class RecommendedApp(Base):
|
|||
class InstalledApp(Base):
|
||||
__tablename__ = "installed_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||
db.Index("installed_app_tenant_id_idx", "tenant_id"),
|
||||
db.Index("installed_app_app_id_idx", "app_id"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
|
||||
sa.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||
sa.Index("installed_app_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("installed_app_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
app_owner_tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)
|
||||
is_pinned: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
last_used_at = mapped_column(db.DateTime, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -610,23 +610,23 @@ class InstalledApp(Base):
|
|||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
app_model_config_id = mapped_column(StringUUID, nullable=True)
|
||||
model_provider = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(db.Text)
|
||||
override_model_configs = mapped_column(sa.Text)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
mode: Mapped[str] = mapped_column(String(255))
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
summary = mapped_column(db.Text)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
introduction = mapped_column(db.Text)
|
||||
system_instruction = mapped_column(db.Text)
|
||||
system_instruction_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
summary = mapped_column(sa.Text)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
introduction = mapped_column(sa.Text)
|
||||
system_instruction = mapped_column(sa.Text)
|
||||
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
status: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
# The `invoke_from` records how the conversation is created.
|
||||
|
|
@ -639,18 +639,18 @@ class Conversation(Base):
|
|||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
read_at = mapped_column(db.DateTime)
|
||||
read_at = mapped_column(sa.DateTime)
|
||||
read_account_id = mapped_column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||
message_annotations = db.relationship(
|
||||
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
|
||||
)
|
||||
|
||||
is_deleted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
|
|
@ -892,36 +892,36 @@ class Message(Base):
|
|||
Index("message_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
model_provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
override_model_configs = mapped_column(db.Text)
|
||||
conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
|
||||
query: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
message = mapped_column(db.JSON, nullable=False)
|
||||
message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column
|
||||
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
override_model_configs = mapped_column(sa.Text)
|
||||
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
|
||||
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
|
||||
query: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
message = mapped_column(sa.JSON, nullable=False)
|
||||
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
parent_message_id = mapped_column(StringUUID, nullable=True)
|
||||
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = mapped_column(db.Numeric(10, 7))
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
error = mapped_column(db.Text)
|
||||
message_metadata = mapped_column(db.Text)
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
error = mapped_column(sa.Text)
|
||||
message_metadata = mapped_column(sa.Text)
|
||||
invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
agent_based: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
|
||||
@property
|
||||
|
|
@ -1228,23 +1228,23 @@ class Message(Base):
|
|||
class MessageFeedback(Base):
|
||||
__tablename__ = "message_feedbacks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
db.Index("message_feedback_app_idx", "app_id"),
|
||||
db.Index("message_feedback_message_idx", "message_id", "from_source"),
|
||||
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
sa.Index("message_feedback_app_idx", "app_id"),
|
||||
sa.Index("message_feedback_message_idx", "message_id", "from_source"),
|
||||
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content = mapped_column(db.Text)
|
||||
content = mapped_column(sa.Text)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id = mapped_column(StringUUID)
|
||||
from_account_id = mapped_column(StringUUID)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def from_account(self):
|
||||
|
|
@ -1270,9 +1270,9 @@ class MessageFeedback(Base):
|
|||
class MessageFile(Base):
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
db.Index("message_file_message_idx", "message_id"),
|
||||
db.Index("message_file_created_by_idx", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
sa.Index("message_file_message_idx", "message_id"),
|
||||
sa.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
|
|
@ -1296,37 +1296,37 @@ class MessageFile(Base):
|
|||
self.created_by_role = created_by_role.value
|
||||
self.created_by = created_by
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAnnotation(Base):
|
||||
__tablename__ = "message_annotations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
db.Index("message_annotation_app_idx", "app_id"),
|
||||
db.Index("message_annotation_conversation_idx", "conversation_id"),
|
||||
db.Index("message_annotation_message_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
sa.Index("message_annotation_app_idx", "app_id"),
|
||||
sa.Index("message_annotation_conversation_idx", "conversation_id"),
|
||||
sa.Index("message_annotation_message_idx", "message_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id"))
|
||||
conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
|
||||
message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
question = db.Column(db.Text, nullable=True)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
hit_count: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
question = db.Column(sa.Text, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1342,24 +1342,24 @@ class MessageAnnotation(Base):
|
|||
class AppAnnotationHitHistory(Base):
|
||||
__tablename__ = "app_annotation_hit_histories"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
db.Index("app_annotation_hit_histories_app_idx", "app_id"),
|
||||
db.Index("app_annotation_hit_histories_account_idx", "account_id"),
|
||||
db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
|
||||
db.Index("app_annotation_hit_histories_message_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
sa.Index("app_annotation_hit_histories_app_idx", "app_id"),
|
||||
sa.Index("app_annotation_hit_histories_account_idx", "account_id"),
|
||||
sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"),
|
||||
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
source = mapped_column(db.Text, nullable=False)
|
||||
question = mapped_column(db.Text, nullable=False)
|
||||
source = mapped_column(sa.Text, nullable=False)
|
||||
question = mapped_column(sa.Text, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
score = mapped_column(Float, nullable=False, server_default=db.text("0"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
annotation_question = mapped_column(db.Text, nullable=False)
|
||||
annotation_content = mapped_column(db.Text, nullable=False)
|
||||
annotation_question = mapped_column(sa.Text, nullable=False)
|
||||
annotation_content = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1380,18 +1380,18 @@ class AppAnnotationHitHistory(Base):
|
|||
class AppAnnotationSetting(Base):
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
db.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
sa.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0"))
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
|
|
@ -1408,58 +1408,58 @@ class AppAnnotationSetting(Base):
|
|||
class OperationLog(Base):
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
|
||||
sa.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
action: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content = mapped_column(db.JSON)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
content = mapped_column(sa.JSON)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class EndUser(Base, UserMixin):
|
||||
__tablename__ = "end_users"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
db.Index("end_user_session_id_idx", "session_id", "type"),
|
||||
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
||||
sa.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||
sa.Index("end_user_session_id_idx", "session_id", "type"),
|
||||
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
external_user_id = mapped_column(String(255), nullable=True)
|
||||
name = mapped_column(String(255))
|
||||
is_anonymous: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
session_id: Mapped[str] = mapped_column()
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMCPServer(Base):
|
||||
__tablename__ = "app_mcp_servers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
|
||||
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
|
||||
sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
|
||||
sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
server_code: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
parameters = mapped_column(db.Text, nullable=False)
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
parameters = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n):
|
||||
|
|
@ -1478,34 +1478,34 @@ class AppMCPServer(Base):
|
|||
class Site(Base):
|
||||
__tablename__ = "sites"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||
db.Index("site_app_id_idx", "app_id"),
|
||||
db.Index("site_code_idx", "code", "status"),
|
||||
sa.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||
sa.Index("site_app_id_idx", "app_id"),
|
||||
sa.Index("site_code_idx", "code", "status"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
icon_type = mapped_column(String(255), nullable=True)
|
||||
icon = mapped_column(String(255))
|
||||
icon_background = mapped_column(String(255))
|
||||
description = mapped_column(db.Text)
|
||||
description = mapped_column(sa.Text)
|
||||
default_language: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
chat_color_theme = mapped_column(String(255))
|
||||
chat_color_theme_inverted: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
copyright = mapped_column(String(255))
|
||||
privacy_policy = mapped_column(String(255))
|
||||
show_workflow_steps: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
|
||||
customize_domain = mapped_column(String(255))
|
||||
customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
prompt_public: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying"))
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
code = mapped_column(String(255))
|
||||
|
||||
@property
|
||||
|
|
@ -1535,19 +1535,19 @@ class Site(Base):
|
|||
class ApiToken(Base):
|
||||
__tablename__ = "api_tokens"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
db.Index("api_token_app_id_type_idx", "app_id", "type"),
|
||||
db.Index("api_token_token_idx", "token", "type"),
|
||||
db.Index("api_token_tenant_idx", "tenant_id", "type"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
sa.Index("api_token_app_id_type_idx", "app_id", "type"),
|
||||
sa.Index("api_token_token_idx", "token", "type"),
|
||||
sa.Index("api_token_tenant_idx", "tenant_id", "type"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
last_used_at = mapped_column(db.DateTime, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
last_used_at = mapped_column(sa.DateTime, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key(prefix, n):
|
||||
|
|
@ -1561,26 +1561,26 @@ class ApiToken(Base):
|
|||
class UploadFile(Base):
|
||||
__tablename__ = "upload_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
db.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||
sa.Index("upload_file_tenant_idx", "tenant_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
extension: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=db.text("'account'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'account'::character varying")
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True)
|
||||
used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
||||
hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
|
|
@ -1623,71 +1623,71 @@ class UploadFile(Base):
|
|||
class ApiRequest(Base):
|
||||
__tablename__ = "api_requests"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||
db.Index("api_request_token_idx", "tenant_id", "api_token_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||
sa.Index("api_request_token_idx", "tenant_id", "api_token_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
api_token_id = mapped_column(StringUUID, nullable=False)
|
||||
path: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
request = mapped_column(db.Text, nullable=True)
|
||||
response = mapped_column(db.Text, nullable=True)
|
||||
request = mapped_column(sa.Text, nullable=True)
|
||||
response = mapped_column(sa.Text, nullable=True)
|
||||
ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageChain(Base):
|
||||
__tablename__ = "message_chains"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||
db.Index("message_chain_message_id_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||
sa.Index("message_chain_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
input = mapped_column(db.Text, nullable=True)
|
||||
output = mapped_column(db.Text, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
input = mapped_column(sa.Text, nullable=True)
|
||||
output = mapped_column(sa.Text, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAgentThought(Base):
|
||||
__tablename__ = "message_agent_thoughts"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
db.Index("message_agent_thought_message_id_idx", "message_id"),
|
||||
db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"),
|
||||
sa.Index("message_agent_thought_message_id_idx", "message_id"),
|
||||
sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
message_chain_id = mapped_column(StringUUID, nullable=True)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
thought = mapped_column(db.Text, nullable=True)
|
||||
tool = mapped_column(db.Text, nullable=True)
|
||||
tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||
tool_input = mapped_column(db.Text, nullable=True)
|
||||
observation = mapped_column(db.Text, nullable=True)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
thought = mapped_column(sa.Text, nullable=True)
|
||||
tool = mapped_column(sa.Text, nullable=True)
|
||||
tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
|
||||
tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text"))
|
||||
tool_input = mapped_column(sa.Text, nullable=True)
|
||||
observation = mapped_column(sa.Text, nullable=True)
|
||||
# plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
|
||||
tool_process_data = mapped_column(db.Text, nullable=True)
|
||||
message = mapped_column(db.Text, nullable=True)
|
||||
message_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
message_unit_price = mapped_column(db.Numeric, nullable=True)
|
||||
message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
message_files = mapped_column(db.Text, nullable=True)
|
||||
answer = db.Column(db.Text, nullable=True)
|
||||
answer_token: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
answer_unit_price = mapped_column(db.Numeric, nullable=True)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
tokens: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
total_price = mapped_column(db.Numeric, nullable=True)
|
||||
tool_process_data = mapped_column(sa.Text, nullable=True)
|
||||
message = mapped_column(sa.Text, nullable=True)
|
||||
message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
message_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
message_files = mapped_column(sa.Text, nullable=True)
|
||||
answer = db.Column(sa.Text, nullable=True)
|
||||
answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
answer_unit_price = mapped_column(sa.Numeric, nullable=True)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
total_price = mapped_column(sa.Numeric, nullable=True)
|
||||
currency = mapped_column(String, nullable=True)
|
||||
latency: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
|
||||
latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
created_by_role = mapped_column(String, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
@property
|
||||
def files(self) -> list:
|
||||
|
|
@ -1769,80 +1769,80 @@ class MessageAgentThought(Base):
|
|||
class DatasetRetrieverResource(Base):
|
||||
__tablename__ = "dataset_retriever_resources"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
db.Index("dataset_retriever_resource_message_id_idx", "message_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()"))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(db.Integer, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_name = mapped_column(db.Text, nullable=False)
|
||||
dataset_name = mapped_column(sa.Text, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=True)
|
||||
document_name = mapped_column(db.Text, nullable=False)
|
||||
data_source_type = mapped_column(db.Text, nullable=True)
|
||||
document_name = mapped_column(sa.Text, nullable=False)
|
||||
data_source_type = mapped_column(sa.Text, nullable=True)
|
||||
segment_id = mapped_column(StringUUID, nullable=True)
|
||||
score: Mapped[Optional[float]] = mapped_column(db.Float, nullable=True)
|
||||
content = mapped_column(db.Text, nullable=False)
|
||||
hit_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
segment_position: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
index_node_hash = mapped_column(db.Text, nullable=True)
|
||||
retriever_from = mapped_column(db.Text, nullable=False)
|
||||
score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||
content = mapped_column(sa.Text, nullable=False)
|
||||
hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
index_node_hash = mapped_column(sa.Text, nullable=True)
|
||||
retriever_from = mapped_column(sa.Text, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
db.Index("tag_type_idx", "type"),
|
||||
db.Index("tag_name_idx", "name"),
|
||||
sa.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
sa.Index("tag_type_idx", "type"),
|
||||
sa.Index("tag_name_idx", "name"),
|
||||
)
|
||||
|
||||
TAG_TYPE_LIST = ["knowledge", "app"]
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TagBinding(Base):
|
||||
__tablename__ = "tag_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||
db.Index("tag_bind_target_id_idx", "target_id"),
|
||||
db.Index("tag_bind_tag_id_idx", "tag_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||
sa.Index("tag_bind_target_id_idx", "target_id"),
|
||||
sa.Index("tag_bind_tag_id_idx", "tag_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
tag_id = mapped_column(StringUUID, nullable=True)
|
||||
target_id = mapped_column(StringUUID, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TraceAppConfig(Base):
|
||||
__tablename__ = "trace_app_config"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||
db.Index("trace_app_config_app_id_idx", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||
sa.Index("trace_app_config_app_id_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tracing_provider = mapped_column(String(255), nullable=True)
|
||||
tracing_config = mapped_column(db.JSON, nullable=True)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
tracing_config = mapped_column(sa.JSON, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
|
||||
@property
|
||||
def tracing_config_dict(self):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from datetime import datetime
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
|
|
@ -47,9 +47,9 @@ class Provider(Base):
|
|||
|
||||
__tablename__ = "providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_pkey"),
|
||||
db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
db.UniqueConstraint(
|
||||
sa.PrimaryKeyConstraint("id", name="provider_pkey"),
|
||||
sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota"
|
||||
),
|
||||
)
|
||||
|
|
@ -60,15 +60,15 @@ class Provider(Base):
|
|||
provider_type: Mapped[str] = mapped_column(
|
||||
String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
)
|
||||
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
|
||||
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
|
||||
quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
|
||||
quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -104,9 +104,9 @@ class ProviderModel(Base):
|
|||
|
||||
__tablename__ = "provider_models"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_model_pkey"),
|
||||
db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
db.UniqueConstraint(
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_pkey"),
|
||||
sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name"
|
||||
),
|
||||
)
|
||||
|
|
@ -116,8 +116,8 @@ class ProviderModel(Base):
|
|||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
|
@ -125,8 +125,8 @@ class ProviderModel(Base):
|
|||
class TenantDefaultModel(Base):
|
||||
__tablename__ = "tenant_default_models"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||
sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
|
|
@ -141,8 +141,8 @@ class TenantDefaultModel(Base):
|
|||
class TenantPreferredModelProvider(Base):
|
||||
__tablename__ = "tenant_preferred_model_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||
sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
|
|
@ -156,8 +156,8 @@ class TenantPreferredModelProvider(Base):
|
|||
class ProviderOrder(Base):
|
||||
__tablename__ = "provider_orders"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
sa.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||
sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
|
|
@ -167,9 +167,9 @@ class ProviderOrder(Base):
|
|||
payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
|
||||
payment_id: Mapped[Optional[str]] = mapped_column(String(191))
|
||||
transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
|
||||
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
|
||||
quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[Optional[str]] = mapped_column(String(40))
|
||||
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
|
||||
total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(
|
||||
String(40), nullable=False, server_default=text("'wait_pay'::character varying")
|
||||
)
|
||||
|
|
@ -187,8 +187,8 @@ class ProviderModelSetting(Base):
|
|||
|
||||
__tablename__ = "provider_model_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
|
||||
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
|
||||
sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
|
|
@ -196,8 +196,8 @@ class ProviderModelSetting(Base):
|
|||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
|
@ -209,8 +209,8 @@ class LoadBalancingModelConfig(Base):
|
|||
|
||||
__tablename__ = "load_balancing_model_configs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
|
||||
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
|
||||
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
|
|
@ -219,7 +219,7 @@ class LoadBalancingModelConfig(Base):
|
|||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -2,50 +2,50 @@ import json
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(Base):
|
||||
__tablename__ = "data_source_oauth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
db.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
sa.Index("source_binding_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info = mapped_column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(Base):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
|
||||
db.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"),
|
||||
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credentials = mapped_column(db.Text, nullable=True) # JSON
|
||||
credentials = mapped_column(sa.Text, nullable=True) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from celery import states # type: ignore
|
||||
from sqlalchemy import DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
|
@ -16,7 +17,7 @@ class CeleryTask(Base):
|
|||
|
||||
__tablename__ = "celery_taskmeta"
|
||||
|
||||
id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
task_id = mapped_column(String(155), unique=True)
|
||||
status = mapped_column(String(50), default=states.PENDING)
|
||||
result = mapped_column(db.PickleType, nullable=True)
|
||||
|
|
@ -26,12 +27,12 @@ class CeleryTask(Base):
|
|||
onupdate=lambda: naive_utc_now(),
|
||||
nullable=True,
|
||||
)
|
||||
traceback = mapped_column(db.Text, nullable=True)
|
||||
traceback = mapped_column(sa.Text, nullable=True)
|
||||
name = mapped_column(String(155), nullable=True)
|
||||
args = mapped_column(db.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(db.LargeBinary, nullable=True)
|
||||
args = mapped_column(sa.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(sa.LargeBinary, nullable=True)
|
||||
worker = mapped_column(String(155), nullable=True)
|
||||
retries: Mapped[Optional[int]] = mapped_column(db.Integer, nullable=True)
|
||||
retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||
queue = mapped_column(String(155), nullable=True)
|
||||
|
||||
|
||||
|
|
@ -41,7 +42,7 @@ class CeleryTaskSet(Base):
|
|||
__tablename__ = "celery_tasksetmeta"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
)
|
||||
taskset_id = mapped_column(String(155), unique=True)
|
||||
result = mapped_column(db.PickleType, nullable=True)
|
||||
|
|
|
|||
|
|
@ -25,33 +25,33 @@ from .types import StringUUID
|
|||
class ToolOAuthSystemClient(Base):
|
||||
__tablename__ = "tool_oauth_system_clients"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
plugin_id = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
|
||||
# tenant level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthTenantClient(Base):
|
||||
__tablename__ = "tool_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict:
|
||||
|
|
@ -65,14 +65,14 @@ class BuiltinToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_builtin_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
|
||||
)
|
||||
|
||||
# id of the tool provider
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
|
||||
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
|
||||
)
|
||||
# id of the tenant
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
|
|
@ -81,19 +81,19 @@ class BuiltinToolProvider(Base):
|
|||
# name of the tool provider
|
||||
provider: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
# credential type, e.g., "api-key", "oauth2"
|
||||
credential_type: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, server_default=db.text("'api-key'::character varying")
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
|
||||
)
|
||||
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
|
|
@ -107,28 +107,28 @@ class ApiToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_api_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
|
||||
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
|
||||
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the api provider
|
||||
name = mapped_column(String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
|
||||
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
|
||||
# icon
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema = mapped_column(db.Text, nullable=False)
|
||||
schema = mapped_column(sa.Text, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# who created this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description = mapped_column(db.Text, nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
# json format tools
|
||||
tools_str = mapped_column(db.Text, nullable=False)
|
||||
tools_str = mapped_column(sa.Text, nullable=False)
|
||||
# json format credentials
|
||||
credentials_str = mapped_column(db.Text, nullable=False)
|
||||
credentials_str = mapped_column(sa.Text, nullable=False)
|
||||
# privacy policy
|
||||
privacy_policy = mapped_column(String(255), nullable=True)
|
||||
# custom_disclaimer
|
||||
|
|
@ -167,11 +167,11 @@ class ToolLabelBinding(Base):
|
|||
|
||||
__tablename__ = "tool_label_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
|
||||
db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"),
|
||||
sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# tool id
|
||||
tool_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# tool type
|
||||
|
|
@ -187,12 +187,12 @@ class WorkflowToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_workflow_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
|
||||
db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"),
|
||||
sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"),
|
||||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the workflow provider
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# label of the workflow provider
|
||||
|
|
@ -208,17 +208,17 @@ class WorkflowToolProvider(Base):
|
|||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# parameter configuration
|
||||
parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]")
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
|
||||
# privacy policy
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -245,19 +245,19 @@ class MCPToolProvider(Base):
|
|||
|
||||
__tablename__ = "tool_mcp_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
|
||||
db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
|
||||
db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
|
||||
sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
|
||||
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# name of the mcp provider
|
||||
name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# server identifier of the mcp provider
|
||||
server_identifier: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# encrypted url of the mcp provider
|
||||
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
server_url: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# hash of server_url for uniqueness check
|
||||
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# icon of the mcp provider
|
||||
|
|
@ -267,16 +267,16 @@ class MCPToolProvider(Base):
|
|||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# encrypted credentials
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
# authed
|
||||
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
|
||||
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
# tools
|
||||
tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
|
||||
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
|
|
@ -347,9 +347,9 @@ class ToolModelInvoke(Base):
|
|||
"""
|
||||
|
||||
__tablename__ = "tool_model_invokes"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# who invoke this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
|
|
@ -361,18 +361,18 @@ class ToolModelInvoke(Base):
|
|||
# tool name
|
||||
tool_name = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
model_parameters = mapped_column(db.Text, nullable=False)
|
||||
model_parameters = mapped_column(sa.Text, nullable=False)
|
||||
# prompt messages
|
||||
prompt_messages = mapped_column(db.Text, nullable=False)
|
||||
prompt_messages = mapped_column(sa.Text, nullable=False)
|
||||
# invoke response
|
||||
model_response = mapped_column(db.Text, nullable=False)
|
||||
model_response = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
prompt_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||
provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = mapped_column(db.Numeric(10, 7))
|
||||
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -386,13 +386,13 @@ class ToolConversationVariables(Base):
|
|||
|
||||
__tablename__ = "tool_conversation_variables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"),
|
||||
# add index for user_id and conversation_id
|
||||
db.Index("user_id_idx", "user_id"),
|
||||
db.Index("conversation_id_idx", "conversation_id"),
|
||||
sa.Index("user_id_idx", "user_id"),
|
||||
sa.Index("conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# conversation user id
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
|
|
@ -400,7 +400,7 @@ class ToolConversationVariables(Base):
|
|||
# conversation id
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
# variables pool
|
||||
variables_str = mapped_column(db.Text, nullable=False)
|
||||
variables_str = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -417,11 +417,11 @@ class ToolFile(Base):
|
|||
|
||||
__tablename__ = "tool_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
db.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="tool_file_pkey"),
|
||||
sa.Index("tool_file_conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# conversation user id
|
||||
user_id: Mapped[str] = mapped_column(StringUUID)
|
||||
# tenant id
|
||||
|
|
@ -448,30 +448,30 @@ class DeprecatedPublishedAppTool(Base):
|
|||
|
||||
__tablename__ = "tool_published_apps"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
|
||||
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
|
||||
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# id of the app
|
||||
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who published this tool
|
||||
description = mapped_column(db.Text, nullable=False)
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = mapped_column(db.Text, nullable=False)
|
||||
llm_description = mapped_column(sa.Text, nullable=False)
|
||||
# query description, query will be seem as a parameter of the tool,
|
||||
# to describe this parameter to llm, we need this field
|
||||
query_description = mapped_column(db.Text, nullable=False)
|
||||
query_description = mapped_column(sa.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = mapped_column(String(40), nullable=False)
|
||||
# name of the tool provider
|
||||
tool_name = mapped_column(String(40), nullable=False)
|
||||
# author
|
||||
author = mapped_column(String(40), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> I18nObject:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
|
@ -13,15 +14,15 @@ from .types import StringUUID
|
|||
class SavedMessage(Base):
|
||||
__tablename__ = "saved_messages"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role = mapped_column(
|
||||
String(255), nullable=False, server_default=db.text("'end_user'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
@ -34,15 +35,15 @@ class SavedMessage(Base):
|
|||
class PinnedConversation(Base):
|
||||
__tablename__ = "pinned_conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_by_role = mapped_column(
|
||||
String(255), nullable=False, server_default=db.text("'end_user'::character varying")
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from enum import Enum, StrEnum
|
|||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import DateTime, orm
|
||||
|
||||
|
|
@ -24,7 +25,6 @@ from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
|||
if TYPE_CHECKING:
|
||||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
|
||||
|
|
@ -117,11 +117,11 @@ class Workflow(Base):
|
|||
|
||||
__tablename__ = "workflows"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_pkey"),
|
||||
db.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_pkey"),
|
||||
sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
|
@ -140,10 +140,10 @@ class Workflow(Base):
|
|||
server_onupdate=func.current_timestamp(),
|
||||
)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", db.Text, nullable=False, server_default="{}"
|
||||
"environment_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_conversation_variables: Mapped[str] = mapped_column(
|
||||
"conversation_variables", db.Text, nullable=False, server_default="{}"
|
||||
"conversation_variables", sa.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
|
@ -491,11 +491,11 @@ class WorkflowRun(Base):
|
|||
|
||||
__tablename__ = "workflow_runs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
|
||||
db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
|
||||
sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
|
|
@ -503,19 +503,19 @@ class WorkflowRun(Base):
|
|||
type: Mapped[str] = mapped_column(String(255))
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
version: Mapped[str] = mapped_column(String(255))
|
||||
graph: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
graph: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
|
||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
|
||||
error: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||
total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True)
|
||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
|
@ -704,25 +704,25 @@ class WorkflowNodeExecutionModel(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID)
|
||||
triggered_from: Mapped[str] = mapped_column(String(255))
|
||||
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
index: Mapped[int] = mapped_column(db.Integer)
|
||||
index: Mapped[int] = mapped_column(sa.Integer)
|
||||
predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
node_execution_id: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
node_id: Mapped[str] = mapped_column(String(255))
|
||||
node_type: Mapped[str] = mapped_column(String(255))
|
||||
title: Mapped[str] = mapped_column(String(255))
|
||||
inputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
process_data: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
outputs: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
inputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
process_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
outputs: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
status: Mapped[str] = mapped_column(String(255))
|
||||
error: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
|
||||
execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
|
||||
error: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0"))
|
||||
execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(String(255))
|
||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
|
|
@ -834,11 +834,11 @@ class WorkflowAppLog(Base):
|
|||
|
||||
__tablename__ = "workflow_app_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
|
@ -864,6 +864,19 @@ class WorkflowAppLog(Base):
|
|||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"created_from": self.created_from,
|
||||
"created_by_role": self.created_by_role,
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class ConversationVariable(Base):
|
||||
__tablename__ = "workflow_conversation_variables"
|
||||
|
|
@ -871,7 +884,7 @@ class ConversationVariable(Base):
|
|||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
data: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
data: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), index=True
|
||||
)
|
||||
|
|
@ -933,7 +946,7 @@ class WorkflowDraftVariable(Base):
|
|||
__allow_unmapped__ = True
|
||||
|
||||
# id is the unique identifier of a draft variable.
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()"))
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
|
|
|
|||
|
|
@ -49,6 +49,8 @@ dependencies = [
|
|||
"opentelemetry-instrumentation==0.48b0",
|
||||
"opentelemetry-instrumentation-celery==0.48b0",
|
||||
"opentelemetry-instrumentation-flask==0.48b0",
|
||||
"opentelemetry-instrumentation-redis==0.48b0",
|
||||
"opentelemetry-instrumentation-requests==0.48b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
|
||||
"opentelemetry-propagator-b3==1.27.0",
|
||||
# opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0),
|
||||
|
|
@ -114,6 +116,7 @@ dev = [
|
|||
"pytest-cov~=4.1.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"testcontainers~=4.10.0",
|
||||
"types-aiofiles~=24.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import yaml # type: ignore
|
|||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from packaging import version
|
||||
from packaging.version import parse as parse_version
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -269,7 +270,7 @@ class AppDslService:
|
|||
check_dependencies_pending_data = None
|
||||
if dependencies:
|
||||
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
|
||||
elif imported_version <= "0.1.5":
|
||||
elif parse_version(imported_version) <= parse_version("0.1.5"):
|
||||
if "workflow" in data:
|
||||
graph = data.get("workflow", {}).get("graph", {})
|
||||
dependencies_list = self._extract_dependencies_from_workflow_graph(graph)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from openai._exceptions import RateLimitError
|
||||
|
||||
|
|
@ -15,6 +16,7 @@ from libs.helper import RateLimiter
|
|||
from models.model import Account, App, AppMode, EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
|
@ -86,7 +88,8 @@ class AppGenerateService:
|
|||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
AdvancedChatAppGenerator.convert_to_event_stream(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
|
|
@ -101,7 +104,8 @@ class AppGenerateService:
|
|||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
|
|
@ -210,14 +214,27 @@ class AppGenerateService:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
|
||||
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
|
||||
"""
|
||||
Get workflow
|
||||
:param app_model: app model
|
||||
:param invoke_from: invoke from
|
||||
:param workflow_id: optional workflow id to specify a specific version
|
||||
:return:
|
||||
"""
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# If workflow_id is specified, get the specific workflow version
|
||||
if workflow_id:
|
||||
try:
|
||||
workflow_uuid = uuid.UUID(workflow_id)
|
||||
except ValueError:
|
||||
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
|
||||
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}")
|
||||
return workflow
|
||||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# fetch draft workflow by app_model
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
|
|
|||
|
|
@ -159,9 +159,9 @@ class BillingService:
|
|||
):
|
||||
limiter_key = f"{account_id}:{tenant_id}"
|
||||
if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key):
|
||||
from controllers.console.error import CompilanceRateLimitError
|
||||
from controllers.console.error import ComplianceRateLimitError
|
||||
|
||||
raise CompilanceRateLimitError()
|
||||
raise ComplianceRateLimitError()
|
||||
|
||||
json = {
|
||||
"doc_name": doc_name,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Tenant
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from models.workflow import WorkflowAppLog
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
|
@ -21,6 +33,85 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
|
||||
"""
|
||||
Clean up message-related tables to avoid data redundancy.
|
||||
This method cleans up tables that have foreign key relationships with Message.
|
||||
|
||||
Args:
|
||||
session: Database session, the same with the one in process_tenant method
|
||||
tenant_id: Tenant ID for logging purposes
|
||||
batch_message_ids: List of message IDs to clean up
|
||||
"""
|
||||
if not batch_message_ids:
|
||||
return
|
||||
|
||||
# Clean up each related table
|
||||
related_tables = [
|
||||
(MessageFeedback, "message_feedbacks"),
|
||||
(MessageFile, "message_files"),
|
||||
(MessageAnnotation, "message_annotations"),
|
||||
(MessageChain, "message_chains"),
|
||||
(MessageAgentThought, "message_agent_thoughts"),
|
||||
(AppAnnotationHitHistory, "app_annotation_hit_histories"),
|
||||
(SavedMessage, "saved_messages"),
|
||||
]
|
||||
|
||||
for model, table_name in related_tables:
|
||||
# Query records related to expired messages
|
||||
records = (
|
||||
session.query(model)
|
||||
.filter(
|
||||
model.message_id.in_(batch_message_ids), # type: ignore
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(records) == 0:
|
||||
continue
|
||||
|
||||
# Save records before deletion
|
||||
record_ids = [record.id for record in records]
|
||||
try:
|
||||
record_data = []
|
||||
for record in records:
|
||||
try:
|
||||
if hasattr(record, "to_dict"):
|
||||
record_data.append(record.to_dict())
|
||||
else:
|
||||
# if record doesn't have to_dict method, we need to transform it to dict manually
|
||||
record_dict = {}
|
||||
for column in record.__table__.columns:
|
||||
record_dict[column.name] = getattr(record, column.name)
|
||||
record_data.append(record_dict)
|
||||
except Exception:
|
||||
logger.exception("Failed to transform %s record: %s", table_name, record.id)
|
||||
continue
|
||||
|
||||
if record_data:
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(record_data),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to save %s records", table_name)
|
||||
|
||||
session.query(model).filter(
|
||||
model.id.in_(record_ids), # type: ignore
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(record_ids)} "
|
||||
f"{table_name} records for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
|
||||
with flask_app.app_context():
|
||||
|
|
@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
Message.id.in_(message_ids),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
cls._clear_message_related_tables(session, tenant_id, message_ids)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
|
|
@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs:
|
|||
if len(workflow_runs) < batch:
|
||||
break
|
||||
|
||||
while True:
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
workflow_app_logs = (
|
||||
session.query(WorkflowAppLog)
|
||||
.filter(
|
||||
WorkflowAppLog.tenant_id == tenant_id,
|
||||
WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
|
||||
)
|
||||
.limit(batch)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(workflow_app_logs) == 0:
|
||||
break
|
||||
|
||||
# save workflow app logs
|
||||
storage.save(
|
||||
f"free_plan_tenant_expired_logs/"
|
||||
f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
|
||||
f"-{time.time()}.json",
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
[workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs],
|
||||
),
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs]
|
||||
|
||||
# delete workflow app logs
|
||||
session.query(WorkflowAppLog).filter(
|
||||
WorkflowAppLog.id.in_(workflow_app_log_ids),
|
||||
).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}"
|
||||
f" workflow app logs for tenant {tenant_id}"
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process(cls, days: int, batch: int, tenant_ids: list[str]):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -266,7 +266,7 @@ class DatasetService:
|
|||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
|
||||
raise ValueError(f"The dataset is unavailable, due to: {ex.description}")
|
||||
|
||||
@staticmethod
|
||||
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str):
|
||||
|
|
@ -370,7 +370,7 @@ class DatasetService:
|
|||
raise ValueError("External knowledge api id is required.")
|
||||
# Update metadata fields
|
||||
dataset.updated_by = user.id if user else None
|
||||
dataset.updated_at = datetime.datetime.utcnow()
|
||||
dataset.updated_at = naive_utc_now()
|
||||
db.session.add(dataset)
|
||||
|
||||
# Update external knowledge binding
|
||||
|
|
@ -2372,7 +2372,7 @@ class SegmentService:
|
|||
)
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segmment_ids = []
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
|
|
@ -2382,10 +2382,10 @@ class SegmentService:
|
|||
segment.disabled_at = None
|
||||
segment.disabled_by = None
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
|
||||
enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
elif action == "disable":
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
|
|
@ -2399,7 +2399,7 @@ class SegmentService:
|
|||
)
|
||||
if not segments:
|
||||
return
|
||||
real_deal_segmment_ids = []
|
||||
real_deal_segment_ids = []
|
||||
for segment in segments:
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
|
|
@ -2409,10 +2409,10 @@ class SegmentService:
|
|||
segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
segment.disabled_by = current_user.id
|
||||
db.session.add(segment)
|
||||
real_deal_segmment_ids.append(segment.id)
|
||||
real_deal_segment_ids.append(segment.id)
|
||||
db.session.commit()
|
||||
|
||||
disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id)
|
||||
disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
|
||||
|
|
@ -2670,7 +2670,7 @@ class SegmentService:
|
|||
# check segment
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
|
|
|
|||
|
|
@ -52,6 +52,16 @@ class EnterpriseService:
|
|||
|
||||
return data.get("result", False)
|
||||
|
||||
@classmethod
|
||||
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
|
||||
if not app_codes:
|
||||
return {}
|
||||
body = {"userId": user_id, "appCodes": app_codes}
|
||||
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
return data.get("permissions", {})
|
||||
|
||||
@classmethod
|
||||
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||
if not app_id:
|
||||
|
|
|
|||
|
|
@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception):
|
|||
|
||||
class IsDraftWorkflowError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowIdFormatError(Exception):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -79,7 +79,10 @@ class MetadataService:
|
|||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
value = doc_metadata.pop(old_name, None)
|
||||
doc_metadata[name] = value
|
||||
document.doc_metadata = doc_metadata
|
||||
|
|
@ -109,7 +112,10 @@ class MetadataService:
|
|||
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
|
||||
documents = DocumentService.get_document_by_ids(document_ids)
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(metadata.name, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
|
|
@ -137,7 +143,6 @@ class MetadataService:
|
|||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
if documents:
|
||||
|
|
@ -153,6 +158,7 @@ class MetadataService:
|
|||
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Enable built-in field failed")
|
||||
|
|
@ -166,13 +172,15 @@ class MetadataService:
|
|||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
for document in documents:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
if not document.doc_metadata:
|
||||
doc_metadata = {}
|
||||
else:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(BuiltInField.document_name.value, None)
|
||||
doc_metadata.pop(BuiltInField.uploader.value, None)
|
||||
doc_metadata.pop(BuiltInField.upload_date.value, None)
|
||||
|
|
@ -181,6 +189,7 @@ class MetadataService:
|
|||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
document_ids.append(document.id)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
logging.exception("Disable built-in field failed")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import logging
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
|
||||
from models.engine import db
|
||||
|
|
@ -38,7 +39,7 @@ class PluginDataMigration:
|
|||
where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != ''
|
||||
limit 1000"""
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql))
|
||||
rs = conn.execute(sa.text(sql))
|
||||
|
||||
current_iter_count = 0
|
||||
for i in rs:
|
||||
|
|
@ -94,7 +95,7 @@ limit 1000"""
|
|||
:provider_name
|
||||
{update_retrieval_model_sql}
|
||||
where id = :record_id"""
|
||||
conn.execute(db.text(sql), params)
|
||||
conn.execute(sa.text(sql), params)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})",
|
||||
|
|
@ -148,7 +149,7 @@ limit 1000"""
|
|||
params = {"last_id": last_id or ""}
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql), params)
|
||||
rs = conn.execute(sa.text(sql), params)
|
||||
|
||||
current_iter_count = 0
|
||||
batch_updates = []
|
||||
|
|
@ -193,7 +194,7 @@ limit 1000"""
|
|||
SET {provider_column_name} = :updated_value
|
||||
WHERE id = :record_id
|
||||
"""
|
||||
conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
|
||||
conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]",
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Any, Optional
|
|||
from uuid import uuid4
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import tqdm
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -197,7 +198,7 @@ class PluginMigration:
|
|||
"""
|
||||
with Session(db.engine) as session:
|
||||
rs = session.execute(
|
||||
db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
|
||||
)
|
||||
result = []
|
||||
for row in rs:
|
||||
|
|
|
|||
|
|
@ -422,7 +422,7 @@ class WorkflowDraftVariableService:
|
|||
description=conv_var.description,
|
||||
)
|
||||
draft_conv_vars.append(draft_var)
|
||||
_batch_upsert_draft_varaible(
|
||||
_batch_upsert_draft_variable(
|
||||
self._session,
|
||||
draft_conv_vars,
|
||||
policy=_UpsertPolicy.IGNORE,
|
||||
|
|
@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum):
|
|||
OVERWRITE = "overwrite"
|
||||
|
||||
|
||||
def _batch_upsert_draft_varaible(
|
||||
def _batch_upsert_draft_variable(
|
||||
session: Session,
|
||||
draft_vars: Sequence[WorkflowDraftVariable],
|
||||
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
|
||||
|
|
@ -721,7 +721,7 @@ class DraftVariableSaver:
|
|||
draft_vars = self._build_variables_from_start_mapping(outputs)
|
||||
else:
|
||||
draft_vars = self._build_variables_from_mapping(outputs)
|
||||
_batch_upsert_draft_varaible(self._session, draft_vars)
|
||||
_batch_upsert_draft_variable(self._session, draft_vars)
|
||||
|
||||
@staticmethod
|
||||
def _should_variable_be_editable(node_id: str, name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -129,7 +129,10 @@ class WorkflowService:
|
|||
if not workflow:
|
||||
return None
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
|
||||
raise IsDraftWorkflowError(
|
||||
f"Cannot use draft workflow version. Workflow ID: {workflow_id}. "
|
||||
f"Please use a published workflow version or leave workflow_id empty."
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
|
||||
|
|
@ -441,9 +444,9 @@ class WorkflowService:
|
|||
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
Run free workflow node
|
||||
"""
|
||||
# run draft workflow node
|
||||
# run free workflow node
|
||||
start_at = time.perf_counter()
|
||||
|
||||
node_execution = self._handle_node_run_result(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
from collections.abc import Callable
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task # type: ignore
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
|
@ -331,7 +332,7 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str):
|
|||
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
|
||||
while True:
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(query_sql), params)
|
||||
rs = conn.execute(sa.text(query_sql), params)
|
||||
if rs.rowcount == 0:
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
TestContainers-based integration test configuration for Dify API.
|
||||
|
||||
This module provides containerized test infrastructure using TestContainers library
|
||||
to spin up real database and service instances for integration testing. This approach
|
||||
ensures tests run against actual service implementations rather than mocks, providing
|
||||
more reliable and realistic test scenarios.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
from testcontainers.core.container import DockerContainer
|
||||
from testcontainers.core.waiting_utils import wait_for_logs
|
||||
from testcontainers.postgres import PostgresContainer
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from app_factory import create_app
|
||||
from models import db
|
||||
|
||||
# Configure logging for test containers
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyTestContainers:
|
||||
"""
|
||||
Manages all test containers required for Dify integration tests.
|
||||
|
||||
This class provides a centralized way to manage multiple containers
|
||||
needed for comprehensive integration testing, including databases,
|
||||
caches, and search engines.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize container management with default configurations."""
|
||||
self.postgres: Optional[PostgresContainer] = None
|
||||
self.redis: Optional[RedisContainer] = None
|
||||
self.dify_sandbox: Optional[DockerContainer] = None
|
||||
self._containers_started = False
|
||||
logger.info("DifyTestContainers initialized - ready to manage test containers")
|
||||
|
||||
def start_containers_with_env(self) -> None:
|
||||
"""
|
||||
Start all required containers for integration testing.
|
||||
|
||||
This method initializes and starts PostgreSQL, Redis
|
||||
containers with appropriate configurations for Dify testing. Containers
|
||||
are started in dependency order to ensure proper initialization.
|
||||
"""
|
||||
if self._containers_started:
|
||||
logger.info("Containers already started - skipping container startup")
|
||||
return
|
||||
|
||||
logger.info("Starting test containers for Dify integration tests...")
|
||||
|
||||
# Start PostgreSQL container for main application database
|
||||
# PostgreSQL is used for storing user data, workflows, and application state
|
||||
logger.info("Initializing PostgreSQL container...")
|
||||
self.postgres = PostgresContainer(
|
||||
image="postgres:16-alpine",
|
||||
)
|
||||
self.postgres.start()
|
||||
db_host = self.postgres.get_container_host_ip()
|
||||
db_port = self.postgres.get_exposed_port(5432)
|
||||
os.environ["DB_HOST"] = db_host
|
||||
os.environ["DB_PORT"] = str(db_port)
|
||||
os.environ["DB_USERNAME"] = self.postgres.username
|
||||
os.environ["DB_PASSWORD"] = self.postgres.password
|
||||
os.environ["DB_DATABASE"] = self.postgres.dbname
|
||||
logger.info(
|
||||
"PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s",
|
||||
db_host,
|
||||
db_port,
|
||||
self.postgres.username,
|
||||
self.postgres.dbname,
|
||||
)
|
||||
|
||||
# Wait for PostgreSQL to be ready
|
||||
logger.info("Waiting for PostgreSQL to be ready to accept connections...")
|
||||
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
|
||||
logger.info("PostgreSQL container is ready and accepting connections")
|
||||
|
||||
# Install uuid-ossp extension for UUID generation
|
||||
logger.info("Installing uuid-ossp extension...")
|
||||
try:
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=self.postgres.username,
|
||||
password=self.postgres.password,
|
||||
database=self.postgres.dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
|
||||
cursor.close()
|
||||
conn.close()
|
||||
logger.info("uuid-ossp extension installed successfully")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to install uuid-ossp extension: %s", e)
|
||||
|
||||
# Set up storage environment variables
|
||||
os.environ["STORAGE_TYPE"] = "opendal"
|
||||
os.environ["OPENDAL_SCHEME"] = "fs"
|
||||
os.environ["OPENDAL_FS_ROOT"] = "storage"
|
||||
|
||||
# Start Redis container for caching and session management
|
||||
# Redis is used for storing session data, cache entries, and temporary data
|
||||
logger.info("Initializing Redis container...")
|
||||
self.redis = RedisContainer(image="redis:latest", port=6379)
|
||||
self.redis.start()
|
||||
redis_host = self.redis.get_container_host_ip()
|
||||
redis_port = self.redis.get_exposed_port(6379)
|
||||
os.environ["REDIS_HOST"] = redis_host
|
||||
os.environ["REDIS_PORT"] = str(redis_port)
|
||||
logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port)
|
||||
|
||||
# Wait for Redis to be ready
|
||||
logger.info("Waiting for Redis to be ready to accept connections...")
|
||||
wait_for_logs(self.redis, "Ready to accept connections", timeout=30)
|
||||
logger.info("Redis container is ready and accepting connections")
|
||||
|
||||
# Start Dify Sandbox container for code execution environment
|
||||
# Dify Sandbox provides a secure environment for executing user code
|
||||
logger.info("Initializing Dify Sandbox container...")
|
||||
self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest")
|
||||
self.dify_sandbox.with_exposed_ports(8194)
|
||||
self.dify_sandbox.env = {
|
||||
"API_KEY": "test_api_key",
|
||||
}
|
||||
self.dify_sandbox.start()
|
||||
sandbox_host = self.dify_sandbox.get_container_host_ip()
|
||||
sandbox_port = self.dify_sandbox.get_exposed_port(8194)
|
||||
os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}"
|
||||
os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key"
|
||||
logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port)
|
||||
|
||||
# Wait for Dify Sandbox to be ready
|
||||
logger.info("Waiting for Dify Sandbox to be ready to accept connections...")
|
||||
wait_for_logs(self.dify_sandbox, "config init success", timeout=60)
|
||||
logger.info("Dify Sandbox container is ready and accepting connections")
|
||||
|
||||
self._containers_started = True
|
||||
logger.info("All test containers started successfully")
|
||||
|
||||
def stop_containers(self) -> None:
|
||||
"""
|
||||
Stop and clean up all test containers.
|
||||
|
||||
This method ensures proper cleanup of all containers to prevent
|
||||
resource leaks and conflicts between test runs.
|
||||
"""
|
||||
if not self._containers_started:
|
||||
logger.info("No containers to stop - containers were not started")
|
||||
return
|
||||
|
||||
logger.info("Stopping and cleaning up test containers...")
|
||||
containers = [self.redis, self.postgres, self.dify_sandbox]
|
||||
for container in containers:
|
||||
if container:
|
||||
try:
|
||||
container_name = container.image
|
||||
logger.info("Stopping container: %s", container_name)
|
||||
container.stop()
|
||||
logger.info("Successfully stopped container: %s", container_name)
|
||||
except Exception as e:
|
||||
# Log error but don't fail the test cleanup
|
||||
logger.warning("Failed to stop container %s: %s", container, e)
|
||||
|
||||
self._containers_started = False
|
||||
logger.info("All test containers stopped and cleaned up successfully")
|
||||
|
||||
|
||||
# Global container manager instance
|
||||
_container_manager = DifyTestContainers()
|
||||
|
||||
|
||||
def _create_app_with_containers() -> Flask:
|
||||
"""
|
||||
Create Flask application configured to use test containers.
|
||||
|
||||
This function creates a Flask application instance that is configured
|
||||
to connect to the test containers instead of the default development
|
||||
or production databases.
|
||||
|
||||
Returns:
|
||||
Flask: Configured Flask application for containerized testing
|
||||
"""
|
||||
logger.info("Creating Flask application with test container configuration...")
|
||||
|
||||
# Re-create the config after environment variables have been set
|
||||
from configs import dify_config
|
||||
|
||||
# Force re-creation of config with new environment variables
|
||||
dify_config.__dict__.clear()
|
||||
dify_config.__init__()
|
||||
|
||||
# Create and configure the Flask application
|
||||
logger.info("Initializing Flask application...")
|
||||
app = create_app()
|
||||
logger.info("Flask application created successfully")
|
||||
|
||||
# Initialize database schema
|
||||
logger.info("Creating database schema...")
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
logger.info("Database schema created successfully")
|
||||
|
||||
logger.info("Flask application configured and ready for testing")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]:
|
||||
"""
|
||||
Session-scoped fixture to manage test containers.
|
||||
|
||||
This fixture ensures containers are started once per test session
|
||||
and properly cleaned up when all tests are complete. This approach
|
||||
improves test performance by reusing containers across multiple tests.
|
||||
|
||||
Yields:
|
||||
DifyTestContainers: Container manager instance
|
||||
"""
|
||||
logger.info("=== Starting test session container management ===")
|
||||
_container_manager.start_containers_with_env()
|
||||
logger.info("Test containers ready for session")
|
||||
yield _container_manager
|
||||
logger.info("=== Cleaning up test session containers ===")
|
||||
_container_manager.stop_containers()
|
||||
logger.info("Test session container cleanup completed")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def flask_app_with_containers(set_up_containers_and_env) -> Flask:
|
||||
"""
|
||||
Session-scoped Flask application fixture using test containers.
|
||||
|
||||
This fixture provides a Flask application instance that is configured
|
||||
to use the test containers for all database and service connections.
|
||||
|
||||
Args:
|
||||
containers: Container manager fixture
|
||||
|
||||
Returns:
|
||||
Flask: Configured Flask application
|
||||
"""
|
||||
logger.info("=== Creating session-scoped Flask application ===")
|
||||
app = _create_app_with_containers()
|
||||
logger.info("Session-scoped Flask application created successfully")
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
|
||||
"""
|
||||
Request context fixture for containerized Flask application.
|
||||
|
||||
This fixture provides a Flask request context for tests that need
|
||||
to interact with the Flask application within a request scope.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
None: Request context is active during yield
|
||||
"""
|
||||
logger.debug("Creating Flask request context...")
|
||||
with flask_app_with_containers.test_request_context():
|
||||
logger.debug("Flask request context active")
|
||||
yield
|
||||
logger.debug("Flask request context closed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
|
||||
"""
|
||||
Test client fixture for containerized Flask application.
|
||||
|
||||
This fixture provides a Flask test client that can be used to make
|
||||
HTTP requests to the containerized application for integration testing.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
FlaskClient: Test client instance
|
||||
"""
|
||||
logger.debug("Creating Flask test client...")
|
||||
with flask_app_with_containers.test_client() as client:
|
||||
logger.debug("Flask test client ready")
|
||||
yield client
|
||||
logger.debug("Flask test client closed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Database session fixture for containerized testing.
|
||||
|
||||
This fixture provides a SQLAlchemy database session that is connected
|
||||
to the test PostgreSQL container, allowing tests to interact with
|
||||
the database directly.
|
||||
|
||||
Args:
|
||||
flask_app_with_containers: Flask application fixture
|
||||
|
||||
Yields:
|
||||
Session: Database session instance
|
||||
"""
|
||||
logger.debug("Creating database session...")
|
||||
with flask_app_with_containers.app_context():
|
||||
session = db.session()
|
||||
logger.debug("Database session created and ready")
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
from models import ToolFile, UploadFile
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
class TestStorageKeyLoader(unittest.TestCase):
|
||||
"""
|
||||
Integration tests for StorageKeyLoader class.
|
||||
|
||||
Tests the batched loading of storage keys from the database for files
|
||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data before each test method."""
|
||||
self.session = db.session()
|
||||
self.tenant_id = str(uuid4())
|
||||
self.user_id = str(uuid4())
|
||||
self.conversation_id = str(uuid4())
|
||||
|
||||
# Create test data that will be cleaned up after each test
|
||||
self.test_upload_files = []
|
||||
self.test_tool_files = []
|
||||
|
||||
# Create StorageKeyLoader instance
|
||||
self.loader = StorageKeyLoader(self.session, self.tenant_id)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test data after each test method."""
|
||||
self.session.rollback()
|
||||
|
||||
def _create_upload_file(
|
||||
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if storage_key is None:
|
||||
storage_key = f"test_storage_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
key=storage_key,
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
self.session.add(upload_file)
|
||||
self.session.flush()
|
||||
self.test_upload_files.append(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
def _create_tool_file(
|
||||
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if file_key is None:
|
||||
file_key = f"test_file_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
tool_file = ToolFile()
|
||||
tool_file.id = file_id
|
||||
tool_file.user_id = self.user_id
|
||||
tool_file.tenant_id = tenant_id
|
||||
tool_file.conversation_id = self.conversation_id
|
||||
tool_file.file_key = file_key
|
||||
tool_file.mimetype = "text/plain"
|
||||
tool_file.original_url = "http://example.com/file.txt"
|
||||
tool_file.name = "test_tool_file.txt"
|
||||
tool_file.size = 2048
|
||||
|
||||
self.session.add(tool_file)
|
||||
self.session.flush()
|
||||
self.test_tool_files.append(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
def _create_file(
|
||||
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
|
||||
) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
||||
file_related_id = None
|
||||
remote_url = None
|
||||
|
||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
||||
file_related_id = related_id
|
||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
remote_url = "https://example.com/test_file.txt"
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
size=1024,
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
def test_load_storage_keys_local_file(self):
|
||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_remote_url(self):
|
||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_tool_file(self):
|
||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||
# Create test data
|
||||
tool_file = self._create_tool_file()
|
||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_mixed_methods(self):
|
||||
"""Test batch loading with mixed transfer methods."""
|
||||
# Create test data for different transfer methods
|
||||
upload_file1 = self._create_upload_file()
|
||||
upload_file2 = self._create_upload_file()
|
||||
tool_file = self._create_tool_file()
|
||||
|
||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
files = [file1, file2, file3]
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
assert file1._storage_key == upload_file1.key
|
||||
assert file2._storage_key == upload_file2.key
|
||||
assert file3._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_empty_list(self):
|
||||
"""Test with empty file list."""
|
||||
# Should not raise any exceptions
|
||||
self.loader.load_storage_keys([])
|
||||
|
||||
def test_load_storage_keys_tenant_mismatch(self):
|
||||
"""Test tenant_id validation."""
|
||||
# Create file with different tenant_id
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(
|
||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
||||
)
|
||||
|
||||
# Should raise ValueError for tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_missing_file_id(self):
|
||||
"""Test with None file.related_id."""
|
||||
# Create a file with valid parameters first, then manually set related_id to None
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = None
|
||||
|
||||
# Should raise ValueError for None file related_id
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
assert str(context.value) == "file id should not be None."
|
||||
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
||||
"""Test with missing UploadFile database records."""
|
||||
# Create file with non-existent upload file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
||||
"""Test with missing ToolFile database records."""
|
||||
# Create file with non-existent tool file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_invalid_uuid(self):
|
||||
"""Test with invalid UUID format."""
|
||||
# Create a file with valid parameters first, then manually set invalid related_id
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = "invalid-uuid-format"
|
||||
|
||||
# Should raise ValueError for invalid UUID
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_batch_efficiency(self):
|
||||
"""Test batched operations use efficient queries."""
|
||||
# Create multiple files of different types
|
||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
||||
|
||||
files = []
|
||||
files.extend(
|
||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
||||
)
|
||||
files.extend(
|
||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
||||
)
|
||||
|
||||
# Mock the session to count queries
|
||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
||||
assert mock_scalars.call_count == 2
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
for i, file in enumerate(files[:3]):
|
||||
assert file._storage_key == upload_files[i].key
|
||||
for i, file in enumerate(files[3:]):
|
||||
assert file._storage_key == tool_files[i].file_key
|
||||
|
||||
def test_load_storage_keys_tenant_isolation(self):
|
||||
"""Test that tenant isolation works correctly."""
|
||||
# Create files for different tenants
|
||||
other_tenant_id = str(uuid4())
|
||||
|
||||
# Create upload file for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type="local",
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file_other.id = str(uuid4())
|
||||
self.session.add(upload_file_other)
|
||||
self.session.flush()
|
||||
|
||||
# Create file for other tenant but try to load with current tenant's loader
|
||||
file_other = self._create_file(
|
||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError due to tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
# Current tenant's file should still work
|
||||
self.loader.load_storage_keys([file_current])
|
||||
assert file_current._storage_key == upload_file_current.key
|
||||
|
||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
||||
# Create files for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create file for different tenant
|
||||
other_tenant_id = str(uuid4())
|
||||
file_other = self._create_file(
|
||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
)
|
||||
|
||||
# Should raise ValueError on tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_current, file_other])
|
||||
|
||||
assert "invalid file, expected tenant_id" in str(context.value)
|
||||
|
||||
def test_load_storage_keys_duplicate_file_ids(self):
|
||||
"""Test handling of duplicate file IDs in the batch."""
|
||||
# Create upload file
|
||||
upload_file = self._create_upload_file()
|
||||
|
||||
# Create two File objects with same related_id
|
||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should handle duplicates gracefully
|
||||
self.loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Both files should have the same storage key
|
||||
assert file1._storage_key == upload_file.key
|
||||
assert file2._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_session_isolation(self):
|
||||
"""Test that the loader uses the provided session correctly."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Create loader with different session (same underlying connection)
|
||||
|
||||
with Session(bind=db.engine) as other_session:
|
||||
other_loader = StorageKeyLoader(other_session, self.tenant_id)
|
||||
with pytest.raises(ValueError):
|
||||
other_loader.load_storage_keys([file])
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,739 @@
|
|||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.variables.segments import StringSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from models import App, Workflow
|
||||
from models.enums import DraftVariableType
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import (
|
||||
UpdateNotSupportedError,
|
||||
WorkflowDraftVariableService,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowDraftVariableService:
|
||||
"""
|
||||
Comprehensive integration tests for WorkflowDraftVariableService using testcontainers.
|
||||
|
||||
This test class covers all major functionality of the WorkflowDraftVariableService:
|
||||
- CRUD operations for workflow draft variables (Create, Read, Update, Delete)
|
||||
- Variable listing and filtering by type (conversation, system, node)
|
||||
- Variable updates and resets with proper validation
|
||||
- Variable deletion operations at different scopes
|
||||
- Special functionality like prefill and conversation ID retrieval
|
||||
- Error handling for various edge cases and invalid operations
|
||||
|
||||
All tests use the testcontainers infrastructure to ensure proper database isolation
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""
|
||||
Mock setup for external service dependencies.
|
||||
|
||||
WorkflowDraftVariableService doesn't have external dependencies that need mocking,
|
||||
so this fixture returns an empty dictionary to maintain consistency with other test classes.
|
||||
This ensures the test structure remains consistent across different service test files.
|
||||
"""
|
||||
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
|
||||
return {}
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
|
||||
"""
|
||||
Helper method to create a test app with realistic data for testing.
|
||||
|
||||
This method creates a complete App instance with all required fields populated
|
||||
using Faker for generating realistic test data. The app is configured for
|
||||
workflow mode to support workflow draft variable testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies (unused in this service)
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
App: Created test app instance with all required fields populated
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
app = App()
|
||||
app.id = fake.uuid4()
|
||||
app.tenant_id = fake.uuid4()
|
||||
app.name = fake.company()
|
||||
app.description = fake.text()
|
||||
app.mode = "workflow"
|
||||
app.icon_type = "emoji"
|
||||
app.icon = "🤖"
|
||||
app.icon_background = "#FFEAD5"
|
||||
app.enable_site = True
|
||||
app.enable_api = True
|
||||
app.created_by = fake.uuid4()
|
||||
app.updated_by = app.created_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
return app
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
|
||||
"""
|
||||
Helper method to create a test workflow associated with an app.
|
||||
|
||||
This method creates a Workflow instance using the proper factory method
|
||||
to ensure all required fields are set correctly. The workflow is configured
|
||||
as a draft version with basic graph structure for testing workflow variables.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app: The app to associate the workflow with
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
Workflow: Created test workflow instance with proper configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
workflow = Workflow.new(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features="{}",
|
||||
created_by=app.created_by,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
return workflow
|
||||
|
||||
def _create_test_variable(
|
||||
self, db_session_with_containers, app_id, node_id, name, value, variable_type="conversation", fake=None
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow draft variable with proper configuration.
|
||||
|
||||
This method creates different types of variables (conversation, system, node) using
|
||||
the appropriate factory methods to ensure proper initialization. Each variable type
|
||||
has specific requirements and this method handles the creation logic for all types.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
app_id: ID of the app to associate the variable with
|
||||
node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID)
|
||||
name: Name of the variable for identification
|
||||
value: StringSegment value for the variable content
|
||||
variable_type: Type of variable ("conversation", "system", "node") determining creation method
|
||||
fake: Faker instance for generating test data, creates new instance if not provided
|
||||
|
||||
Returns:
|
||||
WorkflowDraftVariable: Created test variable instance with proper type configuration
|
||||
"""
|
||||
fake = fake or Faker()
|
||||
if variable_type == "conversation":
|
||||
# Create conversation variable using the appropriate factory method
|
||||
variable = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
description=fake.text(max_nb_chars=20),
|
||||
)
|
||||
elif variable_type == "system":
|
||||
# Create system variable with editable flag and execution context
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=True,
|
||||
)
|
||||
else: # node variable
|
||||
# Create node variable with visibility and editability settings
|
||||
variable = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
name=name,
|
||||
value=value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a single variable by ID successfully.
|
||||
|
||||
This test verifies that the service can retrieve a specific variable
|
||||
by its ID and that the returned variable contains the correct data.
|
||||
It ensures the basic CRUD read operation works correctly for workflow draft variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(variable.id)
|
||||
assert retrieved_variable is not None
|
||||
assert retrieved_variable.id == variable.id
|
||||
assert retrieved_variable.name == "test_var"
|
||||
assert retrieved_variable.app_id == app.id
|
||||
assert retrieved_variable.get_value().value == test_value.value
|
||||
|
||||
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a variable that doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve a variable with a non-existent ID. This ensures proper
|
||||
handling of missing data scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_id = fake.uuid4()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variable = service.get_variable(non_existent_id)
|
||||
assert retrieved_variable is None
|
||||
|
||||
def test_get_draft_variables_by_selectors_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by selectors successfully.
|
||||
|
||||
This test verifies that the service can retrieve multiple variables
|
||||
using selector pairs (node_id, variable_name) and returns the correct
|
||||
variables for each selector. This is useful for bulk variable retrieval
|
||||
operations in workflow execution contexts.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake
|
||||
)
|
||||
var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake
|
||||
)
|
||||
var3 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, "test_node_1", "var3", var3_value, "node", fake=fake
|
||||
)
|
||||
selectors = [
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var1"],
|
||||
[CONVERSATION_VARIABLE_NODE_ID, "var2"],
|
||||
["test_node_1", "var3"],
|
||||
]
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors)
|
||||
assert len(retrieved_variables) == 3
|
||||
var_names = [var.name for var in retrieved_variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" in var_names
|
||||
for var in retrieved_variables:
|
||||
if var.name == "var1":
|
||||
assert var.get_value().value == var1_value.value
|
||||
elif var.name == "var2":
|
||||
assert var.get_value().value == var2_value.value
|
||||
elif var.name == "var3":
|
||||
assert var.get_value().value == var3_value.value
|
||||
|
||||
def test_list_variables_without_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing variables without values successfully with pagination.
|
||||
|
||||
This test verifies that the service can list variables with pagination
|
||||
and that the returned variables don't include their values (for performance).
|
||||
This is important for scenarios where only variable metadata is needed
|
||||
without loading the actual content.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(5):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_variables_without_values(app.id, page=1, limit=3)
|
||||
assert result.total == 5
|
||||
assert len(result.variables) == 3
|
||||
assert result.variables[0].created_at >= result.variables[1].created_at
|
||||
assert result.variables[1].created_at >= result.variables[2].created_at
|
||||
for var in result.variables:
|
||||
assert var.name is not None
|
||||
assert var.app_id == app.id
|
||||
|
||||
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
variables associated with a specific node ID. This is crucial for
|
||||
workflow execution where variables need to be scoped to specific nodes.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
var1_value = StringSegment(value=fake.word())
|
||||
var2_value = StringSegment(value=fake.word())
|
||||
var3_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var1", var1_value, "node", fake=fake)
|
||||
self._create_test_variable(db_session_with_containers, app.id, node_id, "var2", var3_value, "node", fake=fake)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, "other_node", "var3", var2_value, "node", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_node_variables(app.id, node_id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == node_id
|
||||
assert var.app_id == app.id
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "var1" in var_names
|
||||
assert "var2" in var_names
|
||||
assert "var3" not in var_names
|
||||
|
||||
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing conversation variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
conversation variables, excluding system and node variables.
|
||||
Conversation variables are user-facing variables that can be
|
||||
modified during conversation flows.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conv_var1_value = StringSegment(value=fake.word())
|
||||
conv_var2_value = StringSegment(value=fake.word())
|
||||
conv_var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake
|
||||
)
|
||||
conv_var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake
|
||||
)
|
||||
sys_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var", sys_var_value, "system", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_conversation_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
assert "sys_var" not in var_names
|
||||
|
||||
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating a variable's name and value successfully.
|
||||
|
||||
This test verifies that the service can update both the name and value
|
||||
of an editable variable and that the changes are persisted correctly.
|
||||
It also checks that the last_edited_at timestamp is updated appropriately.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"original_name",
|
||||
original_value,
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
updated_variable = service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert updated_variable.name == "new_name"
|
||||
assert updated_variable.get_value().value == new_value.value
|
||||
assert updated_variable.last_edited_at is not None
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(variable)
|
||||
assert variable.name == "new_name"
|
||||
assert variable.get_value().value == new_value.value
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that updating a non-editable variable raises an exception.
|
||||
|
||||
This test verifies that the service properly prevents updates to
|
||||
variables that are not marked as editable. This is important for
|
||||
maintaining data integrity and preventing unauthorized modifications
|
||||
to system-controlled variables.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
original_value = StringSegment(value=fake.word())
|
||||
new_value = StringSegment(value=fake.word())
|
||||
variable = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=app.id,
|
||||
name=fake.word(), # This is typically not editable
|
||||
value=original_value,
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=False, # Set as non-editable
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
with pytest.raises(UpdateNotSupportedError) as exc_info:
|
||||
service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert "variable not support updating" in str(exc_info.value)
|
||||
assert variable.id in str(exc_info.value)
|
||||
|
||||
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test resetting conversation variable successfully.
|
||||
|
||||
This test verifies that the service can reset a conversation variable
|
||||
to its default value and clear the last_edited_at timestamp.
|
||||
This functionality is useful for reverting user modifications
|
||||
back to the original workflow configuration.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="test_conv_var",
|
||||
value="default_value",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
modified_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
"test_conv_var",
|
||||
modified_value,
|
||||
fake=fake,
|
||||
)
|
||||
variable.last_edited_at = fake.date_time()
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
reset_variable = service.reset_variable(workflow, variable)
|
||||
assert reset_variable is not None
|
||||
assert reset_variable.get_value().value == "default_value"
|
||||
assert reset_variable.last_edited_at is None
|
||||
db.session.refresh(variable)
|
||||
assert variable.get_value().value == "default_value"
|
||||
assert variable.last_edited_at is None
|
||||
|
||||
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting a single variable successfully.
|
||||
|
||||
This test verifies that the service can delete a specific variable
|
||||
and that it's properly removed from the database. It ensures that
|
||||
the deletion operation is atomic and complete.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_variable(variable)
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
|
||||
|
||||
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a workflow successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific app/workflow. This is useful for
|
||||
cleanup operations when workflows are deleted or reset.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
for i in range(3):
|
||||
test_value = StringSegment(value=fake.numerify("value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), test_value, fake=fake
|
||||
)
|
||||
other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
other_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, other_app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), other_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables) == 3
|
||||
assert len(other_app_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_workflow_variables(app.id)
|
||||
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
assert len(app_variables_after) == 0
|
||||
assert len(other_app_variables_after) == 1
|
||||
|
||||
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting all variables for a specific node successfully.
|
||||
|
||||
This test verifies that the service can delete all variables
|
||||
associated with a specific node while preserving variables
|
||||
for other nodes and conversation variables. This is important
|
||||
for node-specific cleanup operations in workflow management.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
node_id = fake.word()
|
||||
for i in range(2):
|
||||
test_value = StringSegment(value=fake.numerify("node_value##"))
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, node_id, fake.word(), test_value, "node", fake=fake
|
||||
)
|
||||
other_node_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, "other_node", fake.word(), other_node_value, "node", fake=fake
|
||||
)
|
||||
conv_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, fake.word(), conv_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
other_node_variables = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables) == 2
|
||||
assert len(other_node_variables) == 1
|
||||
assert len(conv_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_node_variables(app.id, node_id)
|
||||
target_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
other_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(target_node_variables_after) == 0
|
||||
assert len(other_node_variables_after) == 1
|
||||
assert len(conv_variables_after) == 1
|
||||
|
||||
def test_prefill_conversation_variable_default_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prefill conversation variable default values successfully.
|
||||
|
||||
This test verifies that the service can automatically create
|
||||
conversation variables with default values based on the workflow
|
||||
configuration when none exist. This is important for initializing
|
||||
workflow variables with proper defaults from the workflow definition.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake)
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
conv_var1 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var1",
|
||||
value="default_value1",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"],
|
||||
)
|
||||
conv_var2 = StringVariable(
|
||||
id=fake.uuid4(),
|
||||
name="conv_var2",
|
||||
value="default_value2",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var1, conv_var2]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.prefill_conversation_variable_default_values(workflow)
|
||||
draft_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
assert len(draft_variables) == 2
|
||||
var_names = [var.name for var in draft_variables]
|
||||
assert "conv_var1" in var_names
|
||||
assert "conv_var2" in var_names
|
||||
for var in draft_variables:
|
||||
assert var.app_id == app.id
|
||||
assert var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
assert var.editable is True
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID from draft variable successfully.
|
||||
|
||||
This test verifies that the service can extract the conversation ID
|
||||
from a system variable named "conversation_id". This is important
|
||||
for maintaining conversation context across workflow executions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
conversation_id = fake.uuid4()
|
||||
conv_id_value = StringSegment(value=conversation_id)
|
||||
self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
app.id,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
"conversation_id",
|
||||
conv_id_value,
|
||||
"system",
|
||||
fake=fake,
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id == conversation_id
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID when it doesn't exist.
|
||||
|
||||
This test verifies that the service returns None when no
|
||||
conversation_id variable exists for the app. This ensures
|
||||
proper handling of missing conversation context scenarios.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id is None
|
||||
|
||||
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing system variables successfully.
|
||||
|
||||
This test verifies that the service can filter and return only
|
||||
system variables, excluding conversation and node variables.
|
||||
System variables are internal variables used by the workflow
|
||||
engine for maintaining state and context.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
sys_var1_value = StringSegment(value=fake.word())
|
||||
sys_var2_value = StringSegment(value=fake.word())
|
||||
sys_var1 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var1", sys_var1_value, "system", fake=fake
|
||||
)
|
||||
sys_var2 = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "sys_var2", sys_var2_value, "system", fake=fake
|
||||
)
|
||||
conv_var_value = StringSegment(value=fake.word())
|
||||
self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
result = service.list_system_variables(app.id)
|
||||
assert len(result.variables) == 2
|
||||
for var in result.variables:
|
||||
assert var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
assert var.app_id == app.id
|
||||
assert var.get_variable_type() == DraftVariableType.SYS
|
||||
var_names = [var.name for var in result.variables]
|
||||
assert "sys_var1" in var_names
|
||||
assert "sys_var2" in var_names
|
||||
assert "conv_var" not in var_names
|
||||
|
||||
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name successfully for different types.
|
||||
|
||||
This test verifies that the service can retrieve variables by name
|
||||
for different variable types (conversation, system, node). This
|
||||
functionality is important for variable lookup operations during
|
||||
workflow execution and user interactions.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
test_value = StringSegment(value=fake.word())
|
||||
conv_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake
|
||||
)
|
||||
sys_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, SYSTEM_VARIABLE_NODE_ID, "test_sys_var", test_value, "system", fake=fake
|
||||
)
|
||||
node_var = self._create_test_variable(
|
||||
db_session_with_containers, app.id, "test_node", "test_node_var", test_value, "node", fake=fake
|
||||
)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var")
|
||||
assert retrieved_conv_var is not None
|
||||
assert retrieved_conv_var.name == "test_conv_var"
|
||||
assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var")
|
||||
assert retrieved_sys_var is not None
|
||||
assert retrieved_sys_var.name == "test_sys_var"
|
||||
assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var")
|
||||
assert retrieved_node_var is not None
|
||||
assert retrieved_node_var.name == "test_node_var"
|
||||
assert retrieved_node_var.node_id == "test_node"
|
||||
|
||||
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting variables by name when they don't exist.
|
||||
|
||||
This test verifies that the service returns None when trying to
|
||||
retrieve variables by name that don't exist. This ensures proper
|
||||
handling of missing variable scenarios for all variable types.
|
||||
"""
|
||||
fake = Faker()
|
||||
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake)
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var")
|
||||
assert retrieved_conv_var is None
|
||||
retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var")
|
||||
assert retrieved_sys_var is None
|
||||
retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var")
|
||||
assert retrieved_node_var is None
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
|
||||
CODE_LANGUAGE = "unsupported_language"
|
||||
|
||||
|
||||
def test_unsupported_with_code_template():
|
||||
with pytest.raises(CodeExecutionError) as e:
|
||||
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
|
||||
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for JavaScript code executor functionality."""
|
||||
|
||||
def test_javascript_plain(self, flask_app_with_containers):
|
||||
"""Test basic JavaScript code execution with console.log output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = 'console.log("Hello World")'
|
||||
result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
|
||||
assert result_message == "Hello World\n"
|
||||
|
||||
def test_javascript_json(self, flask_app_with_containers):
|
||||
"""Test JavaScript code execution with JSON output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = dedent("""
|
||||
obj = {'Hello': 'World'}
|
||||
console.log(JSON.stringify(obj))
|
||||
""")
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
|
||||
assert result == '{"Hello":"World"}\n'
|
||||
|
||||
def test_javascript_with_code_template(self, flask_app_with_containers):
|
||||
"""Test JavaScript workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
JavascriptCodeProvider, _ = self.javascript_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JAVASCRIPT,
|
||||
code=JavascriptCodeProvider.get_default_code(),
|
||||
inputs={"arg1": "Hello", "arg2": "World"},
|
||||
)
|
||||
assert result == {"result": "HelloWorld"}
|
||||
|
||||
def test_javascript_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test JavaScript template transformer runner script generation"""
|
||||
_, NodeJsTemplateTransformer = self.javascript_imports
|
||||
|
||||
runner_script = NodeJsTemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
import base64
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestJinja2CodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for Jinja2 code executor functionality."""
|
||||
|
||||
def test_jinja2(self, flask_app_with_containers):
|
||||
"""Test basic Jinja2 template execution with variable substitution"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
template = "Hello {{template}}"
|
||||
inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
|
||||
code = (
|
||||
Jinja2TemplateTransformer.get_runner_script()
|
||||
.replace(Jinja2TemplateTransformer._code_placeholder, template)
|
||||
.replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
|
||||
)
|
||||
result = CodeExecutor.execute_code(
|
||||
language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
|
||||
)
|
||||
assert result == "<<RESULT>>Hello World<<RESULT>>\n"
|
||||
|
||||
def test_jinja2_with_code_template(self, flask_app_with_containers):
|
||||
"""Test Jinja2 workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"}
|
||||
)
|
||||
assert result == {"result": "Hello World"}
|
||||
|
||||
def test_jinja2_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test Jinja2 template transformer runner script generation"""
|
||||
_, Jinja2TemplateTransformer = self.jinja2_imports
|
||||
|
||||
runner_script = Jinja2TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from textwrap import dedent
|
||||
|
||||
from .test_utils import CodeExecutorTestMixin
|
||||
|
||||
|
||||
class TestPython3CodeExecutor(CodeExecutorTestMixin):
|
||||
"""Test class for Python3 code executor functionality."""
|
||||
|
||||
def test_python3_plain(self, flask_app_with_containers):
|
||||
"""Test basic Python3 code execution with print output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = 'print("Hello World")'
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
|
||||
assert result == "Hello World\n"
|
||||
|
||||
def test_python3_json(self, flask_app_with_containers):
|
||||
"""Test Python3 code execution with JSON output"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
|
||||
code = dedent("""
|
||||
import json
|
||||
print(json.dumps({'Hello': 'World'}))
|
||||
""")
|
||||
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
|
||||
assert result == '{"Hello": "World"}\n'
|
||||
|
||||
def test_python3_with_code_template(self, flask_app_with_containers):
|
||||
"""Test Python3 workflow code template execution with inputs"""
|
||||
CodeExecutor, CodeLanguage = self.code_executor_imports
|
||||
Python3CodeProvider, _ = self.python3_imports
|
||||
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.PYTHON3,
|
||||
code=Python3CodeProvider.get_default_code(),
|
||||
inputs={"arg1": "Hello", "arg2": "World"},
|
||||
)
|
||||
assert result == {"result": "HelloWorld"}
|
||||
|
||||
def test_python3_get_runner_script(self, flask_app_with_containers):
|
||||
"""Test Python3 template transformer runner script generation"""
|
||||
_, Python3TemplateTransformer = self.python3_imports
|
||||
|
||||
runner_script = Python3TemplateTransformer.get_runner_script()
|
||||
assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1
|
||||
assert runner_script.count(Python3TemplateTransformer._result_tag) == 2
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Test utilities for code executor integration tests.
|
||||
|
||||
This module provides lazy import functions to avoid module loading issues
|
||||
that occur when modules are imported before the flask_app_with_containers fixture
|
||||
has set up the proper environment variables and configuration.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def force_reload_code_executor():
|
||||
"""
|
||||
Force reload the code_executor module to reinitialize code_execution_endpoint_url.
|
||||
|
||||
This function should be called after setting up environment variables
|
||||
to ensure the code_execution_endpoint_url is initialized with the correct value.
|
||||
"""
|
||||
try:
|
||||
import core.helper.code_executor.code_executor
|
||||
|
||||
importlib.reload(core.helper.code_executor.code_executor)
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the test
|
||||
print(f"Warning: Failed to reload code_executor module: {e}")
|
||||
|
||||
|
||||
def get_code_executor_imports():
|
||||
"""
|
||||
Lazy import function for core CodeExecutor classes.
|
||||
|
||||
Returns:
|
||||
tuple: (CodeExecutor, CodeLanguage) classes
|
||||
"""
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
|
||||
return CodeExecutor, CodeLanguage
|
||||
|
||||
|
||||
def get_javascript_imports():
|
||||
"""
|
||||
Lazy import function for JavaScript-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
|
||||
return JavascriptCodeProvider, NodeJsTemplateTransformer
|
||||
|
||||
|
||||
def get_python3_imports():
|
||||
"""
|
||||
Lazy import function for Python3-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (Python3CodeProvider, Python3TemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
|
||||
return Python3CodeProvider, Python3TemplateTransformer
|
||||
|
||||
|
||||
def get_jinja2_imports():
|
||||
"""
|
||||
Lazy import function for Jinja2-specific modules.
|
||||
|
||||
Returns:
|
||||
tuple: (None, Jinja2TemplateTransformer) classes
|
||||
"""
|
||||
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
|
||||
|
||||
return None, Jinja2TemplateTransformer
|
||||
|
||||
|
||||
class CodeExecutorTestMixin:
|
||||
"""
|
||||
Mixin class providing lazy import methods for code executor tests.
|
||||
|
||||
This mixin helps avoid module loading issues by deferring imports
|
||||
until after the flask_app_with_containers fixture has set up the environment.
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
"""
|
||||
Setup method called before each test method.
|
||||
Force reload the code_executor module to ensure fresh initialization.
|
||||
"""
|
||||
force_reload_code_executor()
|
||||
|
||||
@property
|
||||
def code_executor_imports(self):
|
||||
"""Property to get CodeExecutor and CodeLanguage classes."""
|
||||
return get_code_executor_imports()
|
||||
|
||||
@property
|
||||
def javascript_imports(self):
|
||||
"""Property to get JavaScript-specific classes."""
|
||||
return get_javascript_imports()
|
||||
|
||||
@property
|
||||
def python3_imports(self):
|
||||
"""Property to get Python3-specific classes."""
|
||||
return get_python3_imports()
|
||||
|
||||
@property
|
||||
def jinja2_imports(self):
|
||||
"""Property to get Jinja2-specific classes."""
|
||||
return get_jinja2_imports()
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
|
||||
|
|
@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType:
|
|||
value is tested for the is_array_type method.
|
||||
"""
|
||||
# Arrange
|
||||
all_segment_types = set(SegmentType)
|
||||
expected_array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
|
|
@ -58,3 +57,27 @@ class TestSegmentTypeIsArrayType:
|
|||
for seg_type in enum_values:
|
||||
is_array = seg_type.is_array_type()
|
||||
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||
|
||||
|
||||
class TestSegmentTypeIsValidArrayValidation:
|
||||
"""
|
||||
Test SegmentType.is_valid with array types using different validation strategies.
|
||||
"""
|
||||
|
||||
def test_array_validation_all_success(self):
|
||||
value = ["hello", "world", "foo"]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
|
||||
def test_array_validation_all_fail(self):
|
||||
value = ["hello", 123, "world"]
|
||||
# Should return False, since 123 is not a string
|
||||
assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL)
|
||||
|
||||
def test_array_validation_first(self):
|
||||
value = ["hello", 123, None]
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST)
|
||||
|
||||
def test_array_validation_none(self):
|
||||
value = [1, 2, 3]
|
||||
# validation is None, skip
|
||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
|
||||
|
||||
class TestClearFreePlanTenantExpiredLogs:
|
||||
"""Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock database session."""
|
||||
session = Mock(spec=Session)
|
||||
session.query.return_value.filter.return_value.all.return_value = []
|
||||
session.query.return_value.filter.return_value.delete.return_value = 0
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self):
|
||||
"""Create a mock storage object."""
|
||||
storage = Mock()
|
||||
storage.save.return_value = None
|
||||
return storage
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message_ids(self):
|
||||
"""Sample message IDs for testing."""
|
||||
return ["msg-1", "msg-2", "msg-3"]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_records(self):
|
||||
"""Sample records for testing."""
|
||||
records = []
|
||||
for i in range(3):
|
||||
record = Mock()
|
||||
record.id = f"record-{i}"
|
||||
record.to_dict.return_value = {
|
||||
"id": f"record-{i}",
|
||||
"message_id": f"msg-{i}",
|
||||
"created_at": datetime.datetime.now().isoformat(),
|
||||
}
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def test_clear_message_related_tables_empty_message_ids(self, mock_session):
|
||||
"""Test that method returns early when message_ids is empty."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", [])
|
||||
|
||||
# Should not call any database operations
|
||||
mock_session.query.assert_not_called()
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids):
|
||||
"""Test when no related records are found."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call query for each related table but find no records
|
||||
assert mock_session.query.call_count > 0
|
||||
mock_storage.save.assert_not_called()
|
||||
|
||||
def test_clear_message_related_tables_with_records_and_to_dict(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test when records are found and have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call to_dict on each record (called once per table, so 7 times total)
|
||||
for record in sample_records:
|
||||
assert record.to_dict.call_count == 7
|
||||
|
||||
# Should save backup data
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids):
|
||||
"""Test when records are found but don't have to_dict method."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
# Create records without to_dict method
|
||||
records = []
|
||||
for i in range(2):
|
||||
record = Mock()
|
||||
mock_table = Mock()
|
||||
mock_id_column = Mock()
|
||||
mock_id_column.name = "id"
|
||||
mock_message_id_column = Mock()
|
||||
mock_message_id_column.name = "message_id"
|
||||
mock_table.columns = [mock_id_column, mock_message_id_column]
|
||||
record.__table__ = mock_table
|
||||
record.id = f"record-{i}"
|
||||
record.message_id = f"msg-{i}"
|
||||
del record.to_dict
|
||||
records.append(record)
|
||||
|
||||
# Mock records for first table only, empty for others
|
||||
mock_session.query.return_value.filter.return_value.all.side_effect = [
|
||||
records,
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
]
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should save backup data even without to_dict
|
||||
assert mock_storage.save.call_count > 0
|
||||
|
||||
def test_clear_message_related_tables_storage_error_continues(
|
||||
self, mock_session, sample_message_ids, sample_records
|
||||
):
|
||||
"""Test that method continues even when storage.save fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_storage.save.side_effect = Exception("Storage error")
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if backup fails
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids):
|
||||
"""Test that method continues even when record serialization fails."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
record = Mock()
|
||||
record.id = "record-1"
|
||||
record.to_dict.side_effect = Exception("Serialization error")
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [record]
|
||||
|
||||
# Should not raise exception
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should still delete records even if serialization fails
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records):
|
||||
"""Test that deletion is called for found records."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
# Should call delete for each table that has records
|
||||
assert mock_session.query.return_value.filter.return_value.delete.called
|
||||
|
||||
def test_clear_message_related_tables_logging_output(
|
||||
self, mock_session, sample_message_ids, sample_records, capsys
|
||||
):
|
||||
"""Test that logging output is generated."""
|
||||
with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage:
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = sample_records
|
||||
|
||||
ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids)
|
||||
|
||||
pass
|
||||
68
api/uv.lock
68
api/uv.lock
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.11, <3.13"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'",
|
||||
|
|
@ -1265,6 +1265,8 @@ dependencies = [
|
|||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-instrumentation-celery" },
|
||||
{ name = "opentelemetry-instrumentation-flask" },
|
||||
{ name = "opentelemetry-instrumentation-redis" },
|
||||
{ name = "opentelemetry-instrumentation-requests" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy" },
|
||||
{ name = "opentelemetry-propagator-b3" },
|
||||
{ name = "opentelemetry-proto" },
|
||||
|
|
@ -1318,6 +1320,7 @@ dev = [
|
|||
{ name = "pytest-mock" },
|
||||
{ name = "ruff" },
|
||||
{ name = "scipy-stubs" },
|
||||
{ name = "testcontainers" },
|
||||
{ name = "types-aiofiles" },
|
||||
{ name = "types-beautifulsoup4" },
|
||||
{ name = "types-cachetools" },
|
||||
|
|
@ -1447,6 +1450,8 @@ requires-dist = [
|
|||
{ name = "opentelemetry-instrumentation", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-propagator-b3", specifier = "==1.27.0" },
|
||||
{ name = "opentelemetry-proto", specifier = "==1.27.0" },
|
||||
|
|
@ -1500,6 +1505,7 @@ dev = [
|
|||
{ name = "pytest-mock", specifier = "~=3.14.0" },
|
||||
{ name = "ruff", specifier = "~=0.12.3" },
|
||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||
{ name = "testcontainers", specifier = "~=4.10.0" },
|
||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||
{ name = "types-cachetools", specifier = "~=5.5.0" },
|
||||
|
|
@ -1600,6 +1606,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.16"
|
||||
|
|
@ -3654,6 +3674,36 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-redis"
|
||||
version = "0.48b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-requests"
|
||||
version = "0.48b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "opentelemetry-util-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-sqlalchemy"
|
||||
version = "0.48b0"
|
||||
|
|
@ -5468,6 +5518,22 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "testcontainers"
|
||||
version = "4.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "docker" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "urllib3" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tidb-vector"
|
||||
version = "0.0.9"
|
||||
|
|
|
|||
|
|
@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh
|
|||
|
||||
# Unit tests
|
||||
dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
# TestContainers tests
|
||||
dev/pytest/pytest_testcontainers.sh
|
||||
|
|
|
|||
|
|
@ -0,0 +1,7 @@
|
|||
#!/bin/bash
|
||||
set -x
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../.."
|
||||
|
||||
pytest api/tests/test_containers_integration_tests
|
||||
|
|
@ -538,7 +538,7 @@ services:
|
|||
|
||||
milvus-standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.5.0-beta
|
||||
image: milvusdb/milvus:v2.5.15
|
||||
profiles:
|
||||
- milvus
|
||||
command: [ 'milvus', 'run', 'standalone' ]
|
||||
|
|
|
|||
|
|
@ -1087,7 +1087,7 @@ services:
|
|||
|
||||
milvus-standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.5.0-beta
|
||||
image: milvusdb/milvus:v2.5.15
|
||||
profiles:
|
||||
- milvus
|
||||
command: [ 'milvus', 'run', 'standalone' ]
|
||||
|
|
|
|||
|
|
@ -265,7 +265,6 @@ export default translation
|
|||
fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content)
|
||||
|
||||
const allEnKeys = await getKeysFromLanguage('en-US')
|
||||
const allZhKeys = await getKeysFromLanguage('zh-Hans')
|
||||
|
||||
// Test file filtering logic
|
||||
const targetFile = 'components'
|
||||
|
|
@ -563,4 +562,201 @@ export default translation
|
|||
expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys
|
||||
})
|
||||
})
|
||||
|
||||
describe('Auto-remove multiline key-value pairs', () => {
|
||||
// Helper function to simulate removeExtraKeysFromFile logic
|
||||
function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string {
|
||||
const lines = content.split('\n')
|
||||
const linesToRemove: number[] = []
|
||||
|
||||
for (const keyToRemove of keysToRemove) {
|
||||
let targetLineIndex = -1
|
||||
const linesToRemoveForKey: number[] = []
|
||||
|
||||
// Find the key line (simplified for single-level keys in test)
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i]
|
||||
const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`)
|
||||
if (keyPattern.test(line)) {
|
||||
targetLineIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if (targetLineIndex !== -1) {
|
||||
linesToRemoveForKey.push(targetLineIndex)
|
||||
|
||||
// Check if this is a multiline key-value pair
|
||||
const keyLine = lines[targetLineIndex]
|
||||
const trimmedKeyLine = keyLine.trim()
|
||||
|
||||
// If key line ends with ":" (not complete value), it's likely multiline
|
||||
if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) {
|
||||
// Find the value lines that belong to this key
|
||||
let currentLine = targetLineIndex + 1
|
||||
let foundValue = false
|
||||
|
||||
while (currentLine < lines.length) {
|
||||
const line = lines[currentLine]
|
||||
const trimmed = line.trim()
|
||||
|
||||
// Skip empty lines
|
||||
if (trimmed === '') {
|
||||
currentLine++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this line starts a new key (indicates end of current value)
|
||||
if (trimmed.match(/^\w+\s*:/))
|
||||
break
|
||||
|
||||
// Check if this line is part of the value
|
||||
if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) {
|
||||
linesToRemoveForKey.push(currentLine)
|
||||
foundValue = true
|
||||
|
||||
// Check if this line ends the value (ends with quote and comma/no comma)
|
||||
if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,')
|
||||
|| trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`'))
|
||||
&& !trimmed.startsWith('//'))
|
||||
break
|
||||
}
|
||||
else {
|
||||
break
|
||||
}
|
||||
|
||||
currentLine++
|
||||
}
|
||||
}
|
||||
|
||||
linesToRemove.push(...linesToRemoveForKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove duplicates and sort in reverse order
|
||||
const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a)
|
||||
|
||||
for (const lineIndex of uniqueLinesToRemove)
|
||||
lines.splice(lineIndex, 1)
|
||||
|
||||
return lines.join('\n')
|
||||
}
|
||||
|
||||
it('should remove single-line key-value pairs correctly', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'This should stay',
|
||||
removeThis: 'This should be removed',
|
||||
alsoKeep: 'This should also stay',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeThis'])
|
||||
|
||||
expect(result).toContain('keepThis: \'This should stay\'')
|
||||
expect(result).toContain('alsoKeep: \'This should also stay\'')
|
||||
expect(result).not.toContain('removeThis: \'This should be removed\'')
|
||||
})
|
||||
|
||||
it('should remove multiline key-value pairs completely', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'This should stay',
|
||||
removeMultiline:
|
||||
'This is a multiline value that should be removed completely',
|
||||
alsoKeep: 'This should also stay',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeMultiline'])
|
||||
|
||||
expect(result).toContain('keepThis: \'This should stay\'')
|
||||
expect(result).toContain('alsoKeep: \'This should also stay\'')
|
||||
expect(result).not.toContain('removeMultiline:')
|
||||
expect(result).not.toContain('This is a multiline value that should be removed completely')
|
||||
})
|
||||
|
||||
it('should handle mixed single-line and multiline removals', () => {
|
||||
const content = `const translation = {
|
||||
keepThis: 'Keep this',
|
||||
removeSingle: 'Remove this single line',
|
||||
removeMultiline:
|
||||
'Remove this multiline value',
|
||||
anotherMultiline:
|
||||
'Another multiline that spans multiple lines',
|
||||
keepAnother: 'Keep this too',
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline'])
|
||||
|
||||
expect(result).toContain('keepThis: \'Keep this\'')
|
||||
expect(result).toContain('keepAnother: \'Keep this too\'')
|
||||
expect(result).not.toContain('removeSingle:')
|
||||
expect(result).not.toContain('removeMultiline:')
|
||||
expect(result).not.toContain('anotherMultiline:')
|
||||
expect(result).not.toContain('Remove this single line')
|
||||
expect(result).not.toContain('Remove this multiline value')
|
||||
expect(result).not.toContain('Another multiline that spans multiple lines')
|
||||
})
|
||||
|
||||
it('should properly detect multiline vs single-line patterns', () => {
|
||||
const multilineContent = `const translation = {
|
||||
singleLine: 'This is single line',
|
||||
multilineKey:
|
||||
'This is multiline',
|
||||
keyWithColon: 'Value with: colon inside',
|
||||
objectKey: {
|
||||
nested: 'value'
|
||||
},
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
// Test that single line with colon in value is not treated as multiline
|
||||
const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon'])
|
||||
expect(result1).not.toContain('keyWithColon:')
|
||||
expect(result1).not.toContain('Value with: colon inside')
|
||||
|
||||
// Test that true multiline is handled correctly
|
||||
const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey'])
|
||||
expect(result2).not.toContain('multilineKey:')
|
||||
expect(result2).not.toContain('This is multiline')
|
||||
|
||||
// Test that object key removal works (note: this is a simplified test)
|
||||
// In real scenario, object removal would be more complex
|
||||
const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey'])
|
||||
expect(result3).not.toContain('objectKey: {')
|
||||
// Note: Our simplified test function doesn't handle nested object removal perfectly
|
||||
// This is acceptable as it's testing the main multiline string removal functionality
|
||||
})
|
||||
|
||||
it('should handle real-world Polish translation structure', () => {
|
||||
const polishContent = `const translation = {
|
||||
createApp: 'UTWÓRZ APLIKACJĘ',
|
||||
newApp: {
|
||||
captionAppType: 'Jaki typ aplikacji chcesz stworzyć?',
|
||||
chatbotDescription:
|
||||
'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.',
|
||||
agentDescription:
|
||||
'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.',
|
||||
basic: 'Podstawowy',
|
||||
},
|
||||
}
|
||||
|
||||
export default translation`
|
||||
|
||||
const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription'])
|
||||
|
||||
expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'')
|
||||
expect(result).toContain('basic: \'Podstawowy\'')
|
||||
expect(result).not.toContain('captionAppType:')
|
||||
expect(result).not.toContain('chatbotDescription:')
|
||||
expect(result).not.toContain('agentDescription:')
|
||||
expect(result).not.toContain('Jaki typ aplikacji')
|
||||
expect(result).not.toContain('Zbuduj aplikację opartą na czacie')
|
||||
expect(result).not.toContain('Zbuduj inteligentnego agenta')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,305 @@
|
|||
/**
|
||||
* Document Detail Navigation Fix Verification Test
|
||||
*
|
||||
* This test specifically validates that the backToPrev function in the document detail
|
||||
* component correctly preserves pagination and filter states.
|
||||
*/
|
||||
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document'
|
||||
|
||||
// Mock Next.js router
|
||||
const mockPush = jest.fn()
|
||||
jest.mock('next/navigation', () => ({
|
||||
useRouter: jest.fn(() => ({
|
||||
push: mockPush,
|
||||
})),
|
||||
}))
|
||||
|
||||
// Mock the document service hooks
|
||||
jest.mock('@/service/knowledge/use-document', () => ({
|
||||
useDocumentDetail: jest.fn(),
|
||||
useDocumentMetadata: jest.fn(),
|
||||
useInvalidDocumentList: jest.fn(() => jest.fn()),
|
||||
}))
|
||||
|
||||
// Mock other dependencies
|
||||
jest.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContext: jest.fn(() => [null]),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/use-base', () => ({
|
||||
useInvalid: jest.fn(() => jest.fn()),
|
||||
}))
|
||||
|
||||
jest.mock('@/service/knowledge/use-segment', () => ({
|
||||
useSegmentListKey: jest.fn(),
|
||||
useChildSegmentListKey: jest.fn(),
|
||||
}))
|
||||
|
||||
// Create a minimal version of the DocumentDetail component that includes our fix
|
||||
const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => {
|
||||
const router = useRouter()
|
||||
|
||||
// This is the FIXED implementation from detail/index.tsx
|
||||
const backToPrev = () => {
|
||||
// Preserve pagination and filter states when navigating back
|
||||
const searchParams = new URLSearchParams(window.location.search)
|
||||
const queryString = searchParams.toString()
|
||||
const separator = queryString ? '?' : ''
|
||||
const backPath = `/datasets/${datasetId}/documents${separator}${queryString}`
|
||||
router.push(backPath)
|
||||
}
|
||||
|
||||
return (
|
||||
<div data-testid="document-detail-fixed">
|
||||
<button data-testid="back-button-fixed" onClick={backToPrev}>
|
||||
Back to Documents
|
||||
</button>
|
||||
<div data-testid="document-info">
|
||||
Dataset: {datasetId}, Document: {documentId}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('Document Detail Navigation Fix Verification', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Mock successful API responses
|
||||
;(useDocumentDetail as jest.Mock).mockReturnValue({
|
||||
data: {
|
||||
id: 'doc-123',
|
||||
name: 'Test Document',
|
||||
display_status: 'available',
|
||||
enabled: true,
|
||||
archived: false,
|
||||
},
|
||||
error: null,
|
||||
})
|
||||
|
||||
;(useDocumentMetadata as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
error: null,
|
||||
})
|
||||
})
|
||||
|
||||
describe('Query Parameter Preservation', () => {
|
||||
test('preserves pagination state (page 3, limit 25)', () => {
|
||||
// Simulate user coming from page 3 with 25 items per page
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=3&limit=25',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
// User clicks back button
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should preserve the pagination state
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25')
|
||||
|
||||
console.log('✅ Pagination state preserved: page=3&limit=25')
|
||||
})
|
||||
|
||||
test('preserves search keyword and filters', () => {
|
||||
// Simulate user with search and filters applied
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=2&limit=10&keyword=API%20documentation&status=active',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should preserve all query parameters
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active')
|
||||
|
||||
console.log('✅ Search and filters preserved')
|
||||
})
|
||||
|
||||
test('handles complex query parameters with special characters', () => {
|
||||
// Test with complex query string including encoded characters
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// URLSearchParams will normalize the encoding, but preserve all parameters
|
||||
const expectedCall = mockPush.mock.calls[0][0]
|
||||
expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/)
|
||||
expect(expectedCall).toMatch(/page=1/)
|
||||
expect(expectedCall).toMatch(/limit=50/)
|
||||
expect(expectedCall).toMatch(/keyword=test/)
|
||||
expect(expectedCall).toMatch(/sort=name/)
|
||||
expect(expectedCall).toMatch(/order=desc/)
|
||||
|
||||
console.log('✅ Complex query parameters handled:', expectedCall)
|
||||
})
|
||||
|
||||
test('handles empty query parameters gracefully', () => {
|
||||
// No query parameters in URL
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should navigate to clean documents URL
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents')
|
||||
|
||||
console.log('✅ Empty parameters handled gracefully')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Different Dataset IDs', () => {
|
||||
test('works with different dataset identifiers', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=5&limit=10',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
// Test with different dataset ID format
|
||||
render(<DocumentDetailWithFix datasetId="ds-prod-2024-001" documentId="doc-456" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10')
|
||||
|
||||
console.log('✅ Works with different dataset ID formats')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Real User Scenarios', () => {
|
||||
test('scenario: user searches, goes to page 3, views document, clicks back', () => {
|
||||
// User searched for "API" and navigated to page 3
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?keyword=API&page=3&limit=10',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="main-dataset" documentId="api-doc-123" />)
|
||||
|
||||
// User decides to go back to continue browsing
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// Should return to page 3 of API search results
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10')
|
||||
|
||||
console.log('✅ Real user scenario: search + pagination preserved')
|
||||
})
|
||||
|
||||
test('scenario: user applies multiple filters, goes to document, returns', () => {
|
||||
// User has applied multiple filters and is on page 2
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="filtered-dataset" documentId="filtered-doc" />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
|
||||
// All filters should be preserved
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc')
|
||||
|
||||
console.log('✅ Complex filtering scenario preserved')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling and Edge Cases', () => {
|
||||
test('handles malformed query parameters gracefully', () => {
|
||||
// Test with potentially problematic query string
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=invalid&limit=&keyword=test&=emptykey&malformed',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
// Should not throw errors
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
}).not.toThrow()
|
||||
|
||||
// Should still attempt navigation (URLSearchParams will clean up the parameters)
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
const navigationPath = mockPush.mock.calls[0][0]
|
||||
expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/)
|
||||
|
||||
console.log('✅ Malformed parameters handled gracefully:', navigationPath)
|
||||
})
|
||||
|
||||
test('handles very long query strings', () => {
|
||||
// Test with a very long query string
|
||||
const longKeyword = 'a'.repeat(1000)
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: `?page=1&keyword=${longKeyword}`,
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
}).not.toThrow()
|
||||
|
||||
expect(mockPush).toHaveBeenCalled()
|
||||
|
||||
console.log('✅ Long query strings handled')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Performance Verification', () => {
|
||||
test('navigation function executes quickly', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: {
|
||||
search: '?page=1&limit=10&keyword=test',
|
||||
},
|
||||
writable: true,
|
||||
})
|
||||
|
||||
render(<DocumentDetailWithFix datasetId="dataset-123" documentId="doc-456" />)
|
||||
|
||||
const startTime = performance.now()
|
||||
fireEvent.click(screen.getByTestId('back-button-fixed'))
|
||||
const endTime = performance.now()
|
||||
|
||||
const executionTime = endTime - startTime
|
||||
|
||||
// Should execute in less than 10ms
|
||||
expect(executionTime).toBeLessThan(10)
|
||||
|
||||
console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Document List Sorting Tests
|
||||
*/
|
||||
|
||||
describe('Document List Sorting', () => {
|
||||
const mockDocuments = [
|
||||
{ id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 },
|
||||
{ id: '2', name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 },
|
||||
{ id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 },
|
||||
]
|
||||
|
||||
const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => {
|
||||
return [...docs].sort((a, b) => {
|
||||
let aValue: any
|
||||
let bValue: any
|
||||
|
||||
switch (field) {
|
||||
case 'name':
|
||||
aValue = a.name?.toLowerCase() || ''
|
||||
bValue = b.name?.toLowerCase() || ''
|
||||
break
|
||||
case 'word_count':
|
||||
aValue = a.word_count || 0
|
||||
bValue = b.word_count || 0
|
||||
break
|
||||
case 'hit_count':
|
||||
aValue = a.hit_count || 0
|
||||
bValue = b.hit_count || 0
|
||||
break
|
||||
case 'created_at':
|
||||
aValue = a.created_at
|
||||
bValue = b.created_at
|
||||
break
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
if (field === 'name') {
|
||||
const result = aValue.localeCompare(bValue)
|
||||
return order === 'asc' ? result : -result
|
||||
}
|
||||
else {
|
||||
const result = aValue - bValue
|
||||
return order === 'asc' ? result : -result
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
test('sorts by name descending (default for UI consistency)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'name', 'desc')
|
||||
expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt'])
|
||||
})
|
||||
|
||||
test('sorts by name ascending (after toggle)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'name', 'asc')
|
||||
expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx'])
|
||||
})
|
||||
|
||||
test('sorts by word_count descending', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'word_count', 'desc')
|
||||
expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200])
|
||||
})
|
||||
|
||||
test('sorts by hit_count descending', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc')
|
||||
expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5])
|
||||
})
|
||||
|
||||
test('sorts by created_at descending (newest first)', () => {
|
||||
const sorted = sortDocuments(mockDocuments, 'created_at', 'desc')
|
||||
expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400])
|
||||
})
|
||||
|
||||
test('handles empty values correctly', () => {
|
||||
const docsWithEmpty = [
|
||||
{ id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 },
|
||||
{ id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 },
|
||||
]
|
||||
|
||||
const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc')
|
||||
expect(sorted.map(doc => doc.word_count)).toEqual([100, 0])
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
/**
|
||||
* Navigation Utilities Test
|
||||
*
|
||||
* Tests for the navigation utility functions to ensure they handle
|
||||
* query parameter preservation correctly across different scenarios.
|
||||
*/
|
||||
|
||||
import {
|
||||
createBackNavigation,
|
||||
createNavigationPath,
|
||||
createNavigationPathWithParams,
|
||||
datasetNavigation,
|
||||
extractQueryParams,
|
||||
mergeQueryParams,
|
||||
} from '@/utils/navigation'
|
||||
|
||||
// Mock router for testing
|
||||
const mockPush = jest.fn()
|
||||
const mockRouter = { push: mockPush }
|
||||
|
||||
describe('Navigation Utilities', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('createNavigationPath', () => {
|
||||
test('preserves query parameters by default', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test')
|
||||
})
|
||||
|
||||
test('returns clean path when preserveParams is false', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents', false)
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
})
|
||||
|
||||
test('handles empty query parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
// Mock window.location to throw an error
|
||||
Object.defineProperty(window, 'location', {
|
||||
get: () => {
|
||||
throw new Error('Location access denied')
|
||||
},
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const path = createNavigationPath('/datasets/123/documents')
|
||||
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createBackNavigation', () => {
|
||||
test('creates function that navigates with preserved params', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25')
|
||||
})
|
||||
|
||||
test('creates function that navigates without params when specified', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false)
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents')
|
||||
})
|
||||
})
|
||||
|
||||
describe('extractQueryParams', () => {
|
||||
test('extracts specified parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test&other=value' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const params = extractQueryParams(['page', 'limit', 'keyword'])
|
||||
expect(params).toEqual({
|
||||
page: '3',
|
||||
limit: '10',
|
||||
keyword: 'test',
|
||||
})
|
||||
})
|
||||
|
||||
test('handles missing parameters', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const params = extractQueryParams(['page', 'limit', 'missing'])
|
||||
expect(params).toEqual({
|
||||
page: '3',
|
||||
})
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
get: () => {
|
||||
throw new Error('Location access denied')
|
||||
},
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const params = extractQueryParams(['page', 'limit'])
|
||||
|
||||
expect(params).toEqual({})
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe('createNavigationPathWithParams', () => {
|
||||
test('creates path with specified parameters', () => {
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', {
|
||||
page: 1,
|
||||
limit: 25,
|
||||
keyword: 'search term',
|
||||
})
|
||||
|
||||
expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term')
|
||||
})
|
||||
|
||||
test('filters out empty values', () => {
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', {
|
||||
page: 1,
|
||||
limit: '',
|
||||
keyword: 'test',
|
||||
empty: null,
|
||||
undefined,
|
||||
})
|
||||
|
||||
expect(path).toBe('/datasets/123/documents?page=1&keyword=test')
|
||||
})
|
||||
|
||||
test('handles errors gracefully', () => {
|
||||
// Mock URLSearchParams to throw an error
|
||||
const originalURLSearchParams = globalThis.URLSearchParams
|
||||
globalThis.URLSearchParams = jest.fn(() => {
|
||||
throw new Error('URLSearchParams error')
|
||||
}) as any
|
||||
|
||||
const consoleSpy = jest.spyOn(console, 'warn').mockImplementation()
|
||||
const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 })
|
||||
|
||||
expect(path).toBe('/datasets/123/documents')
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error))
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
globalThis.URLSearchParams = originalURLSearchParams
|
||||
})
|
||||
})
|
||||
|
||||
describe('mergeQueryParams', () => {
|
||||
test('merges new params with existing ones', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: 'test', page: '1' })
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toContain('page=1') // overridden
|
||||
expect(result).toContain('limit=10') // preserved
|
||||
expect(result).toContain('keyword=test') // added
|
||||
})
|
||||
|
||||
test('removes parameters when value is null', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10&keyword=test' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: null, filter: 'active' })
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toContain('page=3')
|
||||
expect(result).toContain('limit=10')
|
||||
expect(result).not.toContain('keyword')
|
||||
expect(result).toContain('filter=active')
|
||||
})
|
||||
|
||||
test('creates fresh params when preserveExisting is false', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&limit=10' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const merged = mergeQueryParams({ keyword: 'test' }, false)
|
||||
const result = merged.toString()
|
||||
|
||||
expect(result).toBe('keyword=test')
|
||||
})
|
||||
})
|
||||
|
||||
describe('datasetNavigation', () => {
|
||||
test('backToDocuments creates correct navigation function', () => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=2&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25')
|
||||
})
|
||||
|
||||
test('toDocumentDetail creates correct navigation function', () => {
|
||||
const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456')
|
||||
detailFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456')
|
||||
})
|
||||
|
||||
test('toDocumentSettings creates correct navigation function', () => {
|
||||
const settingsFn = datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456')
|
||||
settingsFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Real-world Integration Scenarios', () => {
|
||||
test('complete user workflow: list -> detail -> back', () => {
|
||||
// User starts on page 3 with search
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=3&keyword=API&limit=25' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
// Create back navigation function (as would be done in detail component)
|
||||
const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset')
|
||||
|
||||
// User clicks back
|
||||
backToDocuments()
|
||||
|
||||
// Should return to exact same list state
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25')
|
||||
})
|
||||
|
||||
test('user applies filters then views document', () => {
|
||||
// Complex filter state
|
||||
Object.defineProperty(window, 'location', {
|
||||
value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' },
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents')
|
||||
backFn()
|
||||
|
||||
expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,396 @@
|
|||
/**
|
||||
* Unified Tags Editing - Pure Logic Tests
|
||||
*
|
||||
* This test file validates the core business logic and state management
|
||||
* behaviors introduced in the recent 7 commits without requiring complex mocks.
|
||||
*/
|
||||
|
||||
describe('Unified Tags Editing - Pure Logic Tests', () => {
|
||||
describe('Tag State Management Logic', () => {
|
||||
it('should detect when tag values have changed', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const newSelectedTagIDs = ['tag1', 'tag3']
|
||||
|
||||
// This is the valueNotChanged logic from TagSelector component
|
||||
const valueNotChanged
|
||||
= currentValue.length === newSelectedTagIDs.length
|
||||
&& currentValue.every(v => newSelectedTagIDs.includes(v))
|
||||
&& newSelectedTagIDs.every(v => currentValue.includes(v))
|
||||
|
||||
expect(valueNotChanged).toBe(false)
|
||||
})
|
||||
|
||||
it('should correctly identify unchanged tag values', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order
|
||||
|
||||
const valueNotChanged
|
||||
= currentValue.length === newSelectedTagIDs.length
|
||||
&& currentValue.every(v => newSelectedTagIDs.includes(v))
|
||||
&& newSelectedTagIDs.every(v => currentValue.includes(v))
|
||||
|
||||
expect(valueNotChanged).toBe(true)
|
||||
})
|
||||
|
||||
it('should calculate correct tag operations for binding/unbinding', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const selectedTagIDs = ['tag2', 'tag3']
|
||||
|
||||
// This is the handleValueChange logic from TagSelector
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual(['tag3'])
|
||||
expect(removeTagIDs).toEqual(['tag1'])
|
||||
})
|
||||
|
||||
it('should handle empty tag arrays correctly', () => {
|
||||
const currentValue: string[] = []
|
||||
const selectedTagIDs = ['tag1']
|
||||
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual(['tag1'])
|
||||
expect(removeTagIDs).toEqual([])
|
||||
expect(currentValue.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
|
||||
it('should handle removing all tags', () => {
|
||||
const currentValue = ['tag1', 'tag2']
|
||||
const selectedTagIDs: string[] = []
|
||||
|
||||
const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v))
|
||||
const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v))
|
||||
|
||||
expect(addTagIDs).toEqual([])
|
||||
expect(removeTagIDs).toEqual(['tag1', 'tag2'])
|
||||
expect(selectedTagIDs.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
})
|
||||
|
||||
describe('Fallback Logic (from layout-main.tsx)', () => {
|
||||
it('should trigger fallback when tags are missing or empty', () => {
|
||||
const appDetailWithoutTags = { tags: [] }
|
||||
const appDetailWithTags = { tags: [{ id: 'tag1' }] }
|
||||
const appDetailWithUndefinedTags = { tags: undefined as any }
|
||||
|
||||
// This simulates the condition in layout-main.tsx
|
||||
const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0
|
||||
const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0
|
||||
const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0
|
||||
|
||||
expect(shouldFallback1).toBe(true) // Empty array should trigger fallback
|
||||
expect(shouldFallback2).toBe(false) // Has tags, no fallback needed
|
||||
expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback
|
||||
})
|
||||
|
||||
it('should preserve tags when fallback succeeds', () => {
|
||||
const originalAppDetail = { tags: [] as any[] }
|
||||
const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] }
|
||||
|
||||
// This simulates the successful fallback in layout-main.tsx
|
||||
if (fallbackResult?.tags)
|
||||
originalAppDetail.tags = fallbackResult.tags
|
||||
|
||||
expect(originalAppDetail.tags).toEqual(fallbackResult.tags)
|
||||
expect(originalAppDetail.tags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should continue with empty tags when fallback fails', () => {
|
||||
const originalAppDetail: { tags: any[] } = { tags: [] }
|
||||
const fallbackResult: { tags?: any[] } | null = null
|
||||
|
||||
// This simulates fallback failure in layout-main.tsx
|
||||
if (fallbackResult?.tags)
|
||||
originalAppDetail.tags = fallbackResult.tags
|
||||
|
||||
expect(originalAppDetail.tags).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('TagSelector Auto-initialization Logic', () => {
|
||||
it('should trigger getTagList when tagList is empty', () => {
|
||||
const tagList: any[] = []
|
||||
let getTagListCalled = false
|
||||
const getTagList = () => {
|
||||
getTagListCalled = true
|
||||
}
|
||||
|
||||
// This simulates the useEffect in TagSelector
|
||||
if (tagList.length === 0)
|
||||
getTagList()
|
||||
|
||||
expect(getTagListCalled).toBe(true)
|
||||
})
|
||||
|
||||
it('should not trigger getTagList when tagList has items', () => {
|
||||
const tagList = [{ id: 'tag1', name: 'existing-tag' }]
|
||||
let getTagListCalled = false
|
||||
const getTagList = () => {
|
||||
getTagListCalled = true
|
||||
}
|
||||
|
||||
// This simulates the useEffect in TagSelector
|
||||
if (tagList.length === 0)
|
||||
getTagList()
|
||||
|
||||
expect(getTagListCalled).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('State Initialization Patterns', () => {
|
||||
it('should maintain AppCard tag state pattern', () => {
|
||||
const app = { tags: [{ id: 'tag1', name: 'test' }] }
|
||||
|
||||
// Original AppCard pattern: useState(app.tags)
|
||||
const initialTags = app.tags
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(1)
|
||||
expect(initialTags).toBe(app.tags) // Reference equality for AppCard
|
||||
})
|
||||
|
||||
it('should maintain AppInfo tag state pattern', () => {
|
||||
const appDetail = { tags: [{ id: 'tag1', name: 'test' }] }
|
||||
|
||||
// New AppInfo pattern: useState(appDetail?.tags || [])
|
||||
const initialTags = appDetail?.tags || []
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should handle undefined appDetail gracefully in AppInfo', () => {
|
||||
const appDetail = undefined
|
||||
|
||||
// AppInfo pattern with undefined appDetail
|
||||
const initialTags = (appDetail as any)?.tags || []
|
||||
expect(Array.isArray(initialTags)).toBe(true)
|
||||
expect(initialTags.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('CSS Class and Layout Logic', () => {
|
||||
it('should apply correct minimum width condition', () => {
|
||||
const minWidth = 'true'
|
||||
|
||||
// This tests the minWidth logic in TagSelector
|
||||
const shouldApplyMinWidth = minWidth && '!min-w-80'
|
||||
expect(shouldApplyMinWidth).toBe('!min-w-80')
|
||||
})
|
||||
|
||||
it('should not apply minimum width when not specified', () => {
|
||||
const minWidth = undefined
|
||||
|
||||
const shouldApplyMinWidth = minWidth && '!min-w-80'
|
||||
expect(shouldApplyMinWidth).toBeFalsy()
|
||||
})
|
||||
|
||||
it('should handle overflow layout classes correctly', () => {
|
||||
// This tests the layout pattern from AppCard and new AppInfo
|
||||
const overflowLayoutClasses = {
|
||||
container: 'flex w-0 grow items-center',
|
||||
inner: 'w-full',
|
||||
truncate: 'truncate',
|
||||
}
|
||||
|
||||
expect(overflowLayoutClasses.container).toContain('w-0 grow')
|
||||
expect(overflowLayoutClasses.inner).toContain('w-full')
|
||||
expect(overflowLayoutClasses.truncate).toBe('truncate')
|
||||
})
|
||||
})
|
||||
|
||||
describe('fetchAppWithTags Service Logic', () => {
|
||||
it('should correctly find app by ID from app list', () => {
|
||||
const appList = [
|
||||
{ id: 'app1', name: 'App 1', tags: [] },
|
||||
{ id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] },
|
||||
{ id: 'app3', name: 'App 3', tags: [] },
|
||||
]
|
||||
const targetAppId = 'test-app-id'
|
||||
|
||||
// This simulates the logic in fetchAppWithTags
|
||||
const foundApp = appList.find(app => app.id === targetAppId)
|
||||
|
||||
expect(foundApp).toBeDefined()
|
||||
expect(foundApp?.id).toBe('test-app-id')
|
||||
expect(foundApp?.tags.length).toBe(1)
|
||||
})
|
||||
|
||||
it('should return null when app not found', () => {
|
||||
const appList = [
|
||||
{ id: 'app1', name: 'App 1' },
|
||||
{ id: 'app2', name: 'App 2' },
|
||||
]
|
||||
const targetAppId = 'nonexistent-app'
|
||||
|
||||
const foundApp = appList.find(app => app.id === targetAppId) || null
|
||||
|
||||
expect(foundApp).toBeNull()
|
||||
})
|
||||
|
||||
it('should handle empty app list', () => {
|
||||
const appList: any[] = []
|
||||
const targetAppId = 'any-app'
|
||||
|
||||
const foundApp = appList.find(app => app.id === targetAppId) || null
|
||||
|
||||
expect(foundApp).toBeNull()
|
||||
expect(appList.length).toBe(0) // Verify empty array usage
|
||||
})
|
||||
})
|
||||
|
||||
describe('Data Structure Validation', () => {
|
||||
it('should maintain consistent tag data structure', () => {
|
||||
const tag = {
|
||||
id: 'tag1',
|
||||
name: 'test-tag',
|
||||
type: 'app',
|
||||
binding_count: 1,
|
||||
}
|
||||
|
||||
expect(tag).toHaveProperty('id')
|
||||
expect(tag).toHaveProperty('name')
|
||||
expect(tag).toHaveProperty('type')
|
||||
expect(tag).toHaveProperty('binding_count')
|
||||
expect(tag.type).toBe('app')
|
||||
expect(typeof tag.binding_count).toBe('number')
|
||||
})
|
||||
|
||||
it('should handle tag arrays correctly', () => {
|
||||
const tags = [
|
||||
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
|
||||
{ id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 },
|
||||
]
|
||||
|
||||
expect(Array.isArray(tags)).toBe(true)
|
||||
expect(tags.length).toBe(2)
|
||||
expect(tags.every(tag => tag.type === 'app')).toBe(true)
|
||||
})
|
||||
|
||||
it('should validate app data structure with tags', () => {
|
||||
const app = {
|
||||
id: 'test-app',
|
||||
name: 'Test App',
|
||||
tags: [
|
||||
{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 },
|
||||
],
|
||||
}
|
||||
|
||||
expect(app).toHaveProperty('id')
|
||||
expect(app).toHaveProperty('name')
|
||||
expect(app).toHaveProperty('tags')
|
||||
expect(Array.isArray(app.tags)).toBe(true)
|
||||
expect(app.tags.length).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Performance and Edge Cases', () => {
|
||||
it('should handle large tag arrays efficiently', () => {
|
||||
const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`)
|
||||
const selectedTags = ['tag1', 'tag50', 'tag99']
|
||||
|
||||
// Performance test: filtering should be efficient
|
||||
const startTime = Date.now()
|
||||
const addTags = selectedTags.filter(tag => !largeTags.includes(tag))
|
||||
const removeTags = largeTags.filter(tag => !selectedTags.includes(tag))
|
||||
const endTime = Date.now()
|
||||
|
||||
expect(endTime - startTime).toBeLessThan(10) // Should be very fast
|
||||
expect(addTags.length).toBe(0) // All selected tags exist
|
||||
expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove
|
||||
})
|
||||
|
||||
it('should handle malformed tag data gracefully', () => {
|
||||
const mixedData = [
|
||||
{ id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 },
|
||||
{ id: 'invalid1' }, // Missing required properties
|
||||
null,
|
||||
undefined,
|
||||
{ id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 },
|
||||
]
|
||||
|
||||
// Filter out invalid entries
|
||||
const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } =>
|
||||
tag != null
|
||||
&& typeof tag === 'object'
|
||||
&& 'id' in tag
|
||||
&& 'name' in tag
|
||||
&& 'type' in tag
|
||||
&& 'binding_count' in tag
|
||||
&& typeof tag.binding_count === 'number',
|
||||
)
|
||||
|
||||
expect(validTags.length).toBe(2)
|
||||
expect(validTags.every(tag => tag.id && tag.name)).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle concurrent tag operations correctly', () => {
|
||||
const operations = [
|
||||
{ type: 'add', tagIds: ['tag1', 'tag2'] },
|
||||
{ type: 'remove', tagIds: ['tag3'] },
|
||||
{ type: 'add', tagIds: ['tag4'] },
|
||||
]
|
||||
|
||||
// Simulate processing operations
|
||||
const results = operations.map(op => ({
|
||||
...op,
|
||||
processed: true,
|
||||
timestamp: Date.now(),
|
||||
}))
|
||||
|
||||
expect(results.length).toBe(3)
|
||||
expect(results.every(result => result.processed)).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Backward Compatibility Verification', () => {
|
||||
it('should not break existing AppCard behavior', () => {
|
||||
// Verify AppCard continues to work with original patterns
|
||||
const originalAppCardLogic = {
|
||||
initializeTags: (app: any) => app.tags,
|
||||
updateTags: (_currentTags: any[], newTags: any[]) => newTags,
|
||||
shouldRefresh: true,
|
||||
}
|
||||
|
||||
const app = { tags: [{ id: 'tag1', name: 'original' }] }
|
||||
const initializedTags = originalAppCardLogic.initializeTags(app)
|
||||
|
||||
expect(initializedTags).toBe(app.tags)
|
||||
expect(originalAppCardLogic.shouldRefresh).toBe(true)
|
||||
})
|
||||
|
||||
it('should ensure AppInfo follows AppCard patterns', () => {
|
||||
// Verify AppInfo uses compatible state management
|
||||
const appCardPattern = (app: any) => app.tags
|
||||
const appInfoPattern = (appDetail: any) => appDetail?.tags || []
|
||||
|
||||
const appWithTags = { tags: [{ id: 'tag1' }] }
|
||||
const appWithoutTags = { tags: [] }
|
||||
const undefinedApp = undefined
|
||||
|
||||
expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags))
|
||||
expect(appInfoPattern(appWithoutTags)).toEqual([])
|
||||
expect(appInfoPattern(undefinedApp)).toEqual([])
|
||||
})
|
||||
|
||||
it('should maintain consistent API parameters', () => {
|
||||
// Verify service layer maintains expected parameters
|
||||
const fetchAppListParams = {
|
||||
url: '/apps',
|
||||
params: { page: 1, limit: 100 },
|
||||
}
|
||||
|
||||
const tagApiParams = {
|
||||
bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }),
|
||||
unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }),
|
||||
}
|
||||
|
||||
expect(fetchAppListParams.url).toBe('/apps')
|
||||
expect(fetchAppListParams.params.limit).toBe(100)
|
||||
|
||||
const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app')
|
||||
expect(bindResult.tagIDs).toEqual(['tag1'])
|
||||
expect(bindResult.type).toBe('app')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,212 @@
|
|||
/**
|
||||
* XSS Fix Verification Test
|
||||
*
|
||||
* This test verifies that the XSS vulnerability in check-code pages has been
|
||||
* properly fixed by replacing dangerouslySetInnerHTML with safe React rendering.
|
||||
*/
|
||||
|
||||
import React from 'react'
|
||||
import { cleanup, render } from '@testing-library/react'
|
||||
import '@testing-library/jest-dom'
|
||||
|
||||
// Mock i18next with the new safe translation structure
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
if (key === 'login.checkCode.tipsPrefix')
|
||||
return 'We send a verification code to '
|
||||
|
||||
return key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock Next.js useSearchParams
|
||||
jest.mock('next/navigation', () => ({
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => {
|
||||
if (key === 'email')
|
||||
return 'test@example.com<script>alert("XSS")</script>'
|
||||
return null
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
// Fixed CheckCode component implementation (current secure version)
|
||||
const SecureCheckCodeComponent = ({ email }: { email: string }) => {
|
||||
const { t } = require('react-i18next').useTranslation()
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h1>Check Code</h1>
|
||||
<p>
|
||||
<span>
|
||||
{t('login.checkCode.tipsPrefix')}
|
||||
<strong>{email}</strong>
|
||||
</span>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Vulnerable implementation for comparison (what we fixed)
|
||||
const VulnerableCheckCodeComponent = ({ email }: { email: string }) => {
|
||||
const mockTranslation = (key: string, params?: any) => {
|
||||
if (key === 'login.checkCode.tips' && params?.email)
|
||||
return `We send a verification code to <strong>${params.email}</strong>`
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h1>Check Code</h1>
|
||||
<p>
|
||||
<span dangerouslySetInnerHTML={{ __html: mockTranslation('login.checkCode.tips', { email }) }}></span>
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
describe('XSS Fix Verification - Check Code Pages Security', () => {
|
||||
afterEach(() => {
|
||||
cleanup()
|
||||
})
|
||||
|
||||
const maliciousEmail = 'test@example.com<script>alert("XSS")</script>'
|
||||
|
||||
it('should securely render email with HTML characters as text (FIXED VERSION)', () => {
|
||||
console.log('\n🔒 Security Fix Verification Report')
|
||||
console.log('===================================')
|
||||
|
||||
const { container } = render(<SecureCheckCodeComponent email={maliciousEmail} />)
|
||||
|
||||
const spanElement = container.querySelector('span')
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log('\n✅ Fixed Implementation Results:')
|
||||
console.log('- Email rendered in strong tag:', strongElement?.textContent)
|
||||
console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('<script>'))
|
||||
console.log('- Script elements created:', scriptElements.length)
|
||||
console.log('- Full text content:', spanElement?.textContent)
|
||||
|
||||
// Verify secure behavior
|
||||
expect(strongElement?.textContent).toBe(maliciousEmail) // Email rendered as text
|
||||
expect(strongElement?.textContent).toContain('<script>') // HTML visible as text
|
||||
expect(scriptElements).toHaveLength(0) // No script elements created
|
||||
expect(spanElement?.textContent).toBe(`We send a verification code to ${maliciousEmail}`)
|
||||
|
||||
console.log('\n🎯 Security Status: SECURE - HTML automatically escaped by React')
|
||||
})
|
||||
|
||||
it('should demonstrate the vulnerability that was fixed (VULNERABLE VERSION)', () => {
|
||||
const { container } = render(<VulnerableCheckCodeComponent email={maliciousEmail} />)
|
||||
|
||||
const spanElement = container.querySelector('span')
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log('\n⚠️ Previous Vulnerable Implementation:')
|
||||
console.log('- HTML content:', spanElement?.innerHTML)
|
||||
console.log('- Strong element text:', strongElement?.textContent)
|
||||
console.log('- Script elements created:', scriptElements.length)
|
||||
console.log('- Script content:', scriptElements[0]?.textContent)
|
||||
|
||||
// Verify vulnerability exists in old implementation
|
||||
expect(scriptElements).toHaveLength(1) // Script element was created
|
||||
expect(scriptElements[0]?.textContent).toBe('alert("XSS")') // Contains malicious code
|
||||
expect(spanElement?.innerHTML).toContain('<script>') // Raw HTML in DOM
|
||||
|
||||
console.log('\n❌ Security Status: VULNERABLE - dangerouslySetInnerHTML creates script elements')
|
||||
})
|
||||
|
||||
it('should verify all affected components use the secure pattern', () => {
|
||||
console.log('\n📋 Component Security Audit')
|
||||
console.log('============================')
|
||||
|
||||
// Test multiple malicious inputs
|
||||
const testCases = [
|
||||
'user@test.com<img src=x onerror=alert(1)>',
|
||||
'test@evil.com<div onclick="alert(2)">click</div>',
|
||||
'admin@site.com<script>document.cookie="stolen"</script>',
|
||||
'normal@email.com',
|
||||
]
|
||||
|
||||
testCases.forEach((testEmail, index) => {
|
||||
const { container } = render(<SecureCheckCodeComponent email={testEmail} />)
|
||||
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
const imgElements = container.querySelectorAll('img')
|
||||
const divElements = container.querySelectorAll('div:not([data-testid])')
|
||||
|
||||
console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`)
|
||||
console.log(` - Script elements: ${scriptElements.length}`)
|
||||
console.log(` - Img elements: ${imgElements.length}`)
|
||||
console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div
|
||||
console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`)
|
||||
|
||||
// All should be safe
|
||||
expect(scriptElements).toHaveLength(0)
|
||||
expect(imgElements).toHaveLength(0)
|
||||
expect(strongElement?.textContent).toBe(testEmail)
|
||||
})
|
||||
|
||||
console.log('\n✅ All test cases passed - secure rendering confirmed')
|
||||
})
|
||||
|
||||
it('should validate the translation structure is secure', () => {
|
||||
console.log('\n🔍 Translation Security Analysis')
|
||||
console.log('=================================')
|
||||
|
||||
const { t } = require('react-i18next').useTranslation()
|
||||
const prefix = t('login.checkCode.tipsPrefix')
|
||||
|
||||
console.log('- Translation key used: login.checkCode.tipsPrefix')
|
||||
console.log('- Translation value:', prefix)
|
||||
console.log('- Contains HTML tags:', prefix.includes('<'))
|
||||
console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>'))
|
||||
|
||||
// Verify translation is plain text
|
||||
expect(prefix).toBe('We send a verification code to ')
|
||||
expect(prefix).not.toContain('<')
|
||||
expect(prefix).not.toContain('>')
|
||||
expect(typeof prefix).toBe('string')
|
||||
|
||||
console.log('\n✅ Translation structure is secure - no HTML content')
|
||||
})
|
||||
|
||||
it('should confirm React automatic escaping works correctly', () => {
|
||||
console.log('\n⚡ React Security Mechanism Test')
|
||||
console.log('=================================')
|
||||
|
||||
// Test React's automatic escaping with various inputs
|
||||
const dangerousInputs = [
|
||||
'<script>alert("xss")</script>',
|
||||
'<img src="x" onerror="alert(1)">',
|
||||
'"><script>alert(2)</script>',
|
||||
'\'>alert(3)</script>',
|
||||
'<div onclick="alert(4)">click</div>',
|
||||
]
|
||||
|
||||
dangerousInputs.forEach((input, index) => {
|
||||
const TestComponent = () => <strong>{input}</strong>
|
||||
const { container } = render(<TestComponent />)
|
||||
|
||||
const strongElement = container.querySelector('strong')
|
||||
const scriptElements = container.querySelectorAll('script')
|
||||
|
||||
console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`)
|
||||
console.log(` - Rendered as text: ${strongElement?.textContent === input}`)
|
||||
console.log(` - No script execution: ${scriptElements.length === 0}`)
|
||||
|
||||
expect(strongElement?.textContent).toBe(input)
|
||||
expect(scriptElements).toHaveLength(0)
|
||||
})
|
||||
|
||||
console.log('\n🛡️ React automatic escaping is working perfectly')
|
||||
})
|
||||
})
|
||||
|
||||
export {}
|
||||
|
|
@ -20,12 +20,18 @@ import cn from '@/utils/classnames'
|
|||
import { useStore } from '@/app/components/app/store'
|
||||
import AppSideBar from '@/app/components/app-sidebar'
|
||||
import type { NavIcon } from '@/app/components/app-sidebar/navLink'
|
||||
import { fetchAppDetail } from '@/service/apps'
|
||||
import { fetchAppDetailDirect } from '@/service/apps'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import type { App } from '@/types/app'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
|
||||
import dynamic from 'next/dynamic'
|
||||
|
||||
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
|
||||
ssr: false,
|
||||
})
|
||||
|
||||
export type IAppDetailLayoutProps = {
|
||||
children: React.ReactNode
|
||||
|
|
@ -48,6 +54,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
|||
setAppDetail: state.setAppDetail,
|
||||
setAppSiderbarExpand: state.setAppSiderbarExpand,
|
||||
})))
|
||||
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
|
||||
const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false)
|
||||
const [appDetailRes, setAppDetailRes] = useState<App | null>(null)
|
||||
const [navigation, setNavigation] = useState<Array<{
|
||||
|
|
@ -111,7 +118,7 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
|||
useEffect(() => {
|
||||
setAppDetail()
|
||||
setIsLoadingAppDetail(true)
|
||||
fetchAppDetail({ url: '/apps', id: appId }).then((res) => {
|
||||
fetchAppDetailDirect({ url: '/apps', id: appId }).then((res: App) => {
|
||||
setAppDetailRes(res)
|
||||
}).catch((e: any) => {
|
||||
if (e.status === 404)
|
||||
|
|
@ -163,6 +170,9 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
|||
<div className="grow overflow-hidden bg-components-panel-bg">
|
||||
{children}
|
||||
</div>
|
||||
{showTagManagementModal && (
|
||||
<TagManagementModal type='app' show={showTagManagementModal} />
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
import React from 'react'
|
||||
import { render } from '@testing-library/react'
|
||||
import '@testing-library/jest-dom'
|
||||
import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing'
|
||||
|
||||
// Mock dependencies to isolate the SVG rendering issue
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('SVG Attribute Error Reproduction', () => {
|
||||
// Capture console errors
|
||||
const originalError = console.error
|
||||
let errorMessages: string[] = []
|
||||
|
||||
beforeEach(() => {
|
||||
errorMessages = []
|
||||
console.error = jest.fn((message) => {
|
||||
errorMessages.push(message)
|
||||
originalError(message)
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
console.error = originalError
|
||||
})
|
||||
|
||||
it('should reproduce inkscape attribute errors when rendering OpikIconBig', () => {
|
||||
console.log('\n=== TESTING OpikIconBig SVG ATTRIBUTE ERRORS ===')
|
||||
|
||||
// Test multiple renders to check for inconsistency
|
||||
for (let i = 0; i < 5; i++) {
|
||||
console.log(`\nRender attempt ${i + 1}:`)
|
||||
|
||||
const { unmount } = render(<OpikIconBig />)
|
||||
|
||||
// Check for specific inkscape attribute errors
|
||||
const inkscapeErrors = errorMessages.filter(msg =>
|
||||
typeof msg === 'string' && msg.includes('inkscape'),
|
||||
)
|
||||
|
||||
if (inkscapeErrors.length > 0) {
|
||||
console.log(`Found ${inkscapeErrors.length} inkscape errors:`)
|
||||
inkscapeErrors.forEach((error, index) => {
|
||||
console.log(` ${index + 1}. ${error.substring(0, 100)}...`)
|
||||
})
|
||||
}
|
||||
else {
|
||||
console.log('No inkscape errors found in this render')
|
||||
}
|
||||
|
||||
unmount()
|
||||
|
||||
// Clear errors for next iteration
|
||||
errorMessages = []
|
||||
}
|
||||
})
|
||||
|
||||
it('should analyze the SVG structure causing the errors', () => {
|
||||
console.log('\n=== ANALYZING SVG STRUCTURE ===')
|
||||
|
||||
// Import the JSON data directly
|
||||
const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json')
|
||||
|
||||
console.log('Icon structure analysis:')
|
||||
console.log('- Root element:', iconData.icon.name)
|
||||
console.log('- Children count:', iconData.icon.children?.length || 0)
|
||||
|
||||
// Find problematic elements
|
||||
const findProblematicElements = (node: any, path = '') => {
|
||||
const problematicElements: any[] = []
|
||||
|
||||
if (node.name && (node.name.includes(':') || node.name.startsWith('sodipodi'))) {
|
||||
problematicElements.push({
|
||||
path,
|
||||
name: node.name,
|
||||
attributes: Object.keys(node.attributes || {}),
|
||||
})
|
||||
}
|
||||
|
||||
// Check attributes for inkscape/sodipodi properties
|
||||
if (node.attributes) {
|
||||
const problematicAttrs = Object.keys(node.attributes).filter(attr =>
|
||||
attr.startsWith('inkscape:') || attr.startsWith('sodipodi:'),
|
||||
)
|
||||
|
||||
if (problematicAttrs.length > 0) {
|
||||
problematicElements.push({
|
||||
path,
|
||||
name: node.name,
|
||||
problematicAttributes: problematicAttrs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (node.children) {
|
||||
node.children.forEach((child: any, index: number) => {
|
||||
problematicElements.push(
|
||||
...findProblematicElements(child, `${path}/${node.name}[${index}]`),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
return problematicElements
|
||||
}
|
||||
|
||||
const problematicElements = findProblematicElements(iconData.icon, 'root')
|
||||
|
||||
console.log(`\n🚨 Found ${problematicElements.length} problematic elements:`)
|
||||
problematicElements.forEach((element, index) => {
|
||||
console.log(`\n${index + 1}. Element: ${element.name}`)
|
||||
console.log(` Path: ${element.path}`)
|
||||
if (element.problematicAttributes)
|
||||
console.log(` Problematic attributes: ${element.problematicAttributes.join(', ')}`)
|
||||
})
|
||||
})
|
||||
|
||||
it('should test the normalizeAttrs function behavior', () => {
|
||||
console.log('\n=== TESTING normalizeAttrs FUNCTION ===')
|
||||
|
||||
const { normalizeAttrs } = require('@/app/components/base/icons/utils')
|
||||
|
||||
const testAttributes = {
|
||||
'inkscape:showpageshadow': '2',
|
||||
'inkscape:pageopacity': '0.0',
|
||||
'inkscape:pagecheckerboard': '0',
|
||||
'inkscape:deskcolor': '#d1d1d1',
|
||||
'sodipodi:docname': 'opik-icon-big.svg',
|
||||
'xmlns:inkscape': 'https://www.inkscape.org/namespaces/inkscape',
|
||||
'xmlns:sodipodi': 'https://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd',
|
||||
'xmlns:svg': 'https://www.w3.org/2000/svg',
|
||||
'data-name': 'Layer 1',
|
||||
'normal-attr': 'value',
|
||||
'class': 'test-class',
|
||||
}
|
||||
|
||||
console.log('Input attributes:', Object.keys(testAttributes))
|
||||
|
||||
const normalized = normalizeAttrs(testAttributes)
|
||||
|
||||
console.log('Normalized attributes:', Object.keys(normalized))
|
||||
console.log('Normalized values:', normalized)
|
||||
|
||||
// Check if problematic attributes are still present
|
||||
const problematicKeys = Object.keys(normalized).filter(key =>
|
||||
key.toLowerCase().includes('inkscape') || key.toLowerCase().includes('sodipodi'),
|
||||
)
|
||||
|
||||
if (problematicKeys.length > 0)
|
||||
console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`)
|
||||
else
|
||||
console.log('✅ No problematic attributes found after normalization')
|
||||
})
|
||||
})
|
||||
|
|
@ -1,12 +1,9 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import {
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import React, { useCallback, useRef, useState } from 'react'
|
||||
|
||||
import type { PopupProps } from './config-popup'
|
||||
import ConfigPopup from './config-popup'
|
||||
import cn from '@/utils/classnames'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
|
|
@ -17,13 +14,13 @@ type Props = {
|
|||
readOnly: boolean
|
||||
className?: string
|
||||
hasConfigured: boolean
|
||||
controlShowPopup?: number
|
||||
children?: React.ReactNode
|
||||
} & PopupProps
|
||||
|
||||
const ConfigBtn: FC<Props> = ({
|
||||
className,
|
||||
hasConfigured,
|
||||
controlShowPopup,
|
||||
children,
|
||||
...popupProps
|
||||
}) => {
|
||||
const [open, doSetOpen] = useState(false)
|
||||
|
|
@ -37,13 +34,6 @@ const ConfigBtn: FC<Props> = ({
|
|||
setOpen(!openRef.current)
|
||||
}, [setOpen])
|
||||
|
||||
useEffect(() => {
|
||||
if (controlShowPopup)
|
||||
// setOpen(!openRef.current)
|
||||
setOpen(true)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [controlShowPopup])
|
||||
|
||||
if (popupProps.readOnly && !hasConfigured)
|
||||
return null
|
||||
|
||||
|
|
@ -52,14 +42,11 @@ const ConfigBtn: FC<Props> = ({
|
|||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
mainAxis: 12,
|
||||
crossAxis: hasConfigured ? 8 : 49,
|
||||
}}
|
||||
offset={12}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={handleTrigger}>
|
||||
<div className={cn(className, 'rounded-md p-1')}>
|
||||
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
|
||||
<div className="select-none">
|
||||
{children}
|
||||
</div>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className='z-[11]'>
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useEffect, useState } from 'react'
|
||||
import React, { useEffect, useState } from 'react'
|
||||
import {
|
||||
RiArrowDownDoubleLine,
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { usePathname } from 'next/navigation'
|
||||
|
|
@ -180,10 +181,6 @@ const Panel: FC = () => {
|
|||
})()
|
||||
}, [])
|
||||
|
||||
const [controlShowPopup, setControlShowPopup] = useState<number>(0)
|
||||
const showPopup = useCallback(() => {
|
||||
setControlShowPopup(Date.now())
|
||||
}, [setControlShowPopup])
|
||||
if (!isLoaded) {
|
||||
return (
|
||||
<div className='mb-3 flex items-center justify-between'>
|
||||
|
|
@ -196,46 +193,66 @@ const Panel: FC = () => {
|
|||
|
||||
return (
|
||||
<div className={cn('flex items-center justify-between')}>
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
|
||||
controlShowPopup && 'border-effects-highlight-lightmode-off bg-background-default-lighter',
|
||||
)}
|
||||
onClick={showPopup}
|
||||
>
|
||||
{!inUseTracingProvider && (
|
||||
<>
|
||||
{!inUseTracingProvider && (
|
||||
<ConfigButton
|
||||
appId={appId}
|
||||
readOnly={readOnly}
|
||||
hasConfigured={false}
|
||||
enabled={enabled}
|
||||
onStatusChange={handleTracingEnabledChange}
|
||||
chosenProvider={inUseTracingProvider}
|
||||
onChooseProvider={handleChooseProvider}
|
||||
arizeConfig={arizeConfig}
|
||||
phoenixConfig={phoenixConfig}
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
onConfigUpdated={handleTracingConfigUpdated}
|
||||
onConfigRemoved={handleTracingConfigRemoved}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer select-none items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
|
||||
)}
|
||||
>
|
||||
<TracingIcon size='md' />
|
||||
<div className='system-sm-semibold mx-2 text-text-secondary'>{t(`${I18N_PREFIX}.title`)}</div>
|
||||
<div className='flex items-center' onClick={e => e.stopPropagation()}>
|
||||
<ConfigButton
|
||||
appId={appId}
|
||||
readOnly={readOnly}
|
||||
hasConfigured={false}
|
||||
enabled={enabled}
|
||||
onStatusChange={handleTracingEnabledChange}
|
||||
chosenProvider={inUseTracingProvider}
|
||||
onChooseProvider={handleChooseProvider}
|
||||
arizeConfig={arizeConfig}
|
||||
phoenixConfig={phoenixConfig}
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
onConfigUpdated={handleTracingConfigUpdated}
|
||||
onConfigRemoved={handleTracingConfigRemoved}
|
||||
controlShowPopup={controlShowPopup}
|
||||
/>
|
||||
<div className='rounded-md p-1'>
|
||||
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
<Divider type='vertical' className='h-3.5' />
|
||||
<div className='rounded-md p-1'>
|
||||
<RiArrowDownDoubleLine className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{hasConfiguredTracing && (
|
||||
<>
|
||||
</div>
|
||||
</ConfigButton>
|
||||
)}
|
||||
{hasConfiguredTracing && (
|
||||
<ConfigButton
|
||||
appId={appId}
|
||||
readOnly={readOnly}
|
||||
hasConfigured
|
||||
enabled={enabled}
|
||||
onStatusChange={handleTracingEnabledChange}
|
||||
chosenProvider={inUseTracingProvider}
|
||||
onChooseProvider={handleChooseProvider}
|
||||
arizeConfig={arizeConfig}
|
||||
phoenixConfig={phoenixConfig}
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
onConfigUpdated={handleTracingConfigUpdated}
|
||||
onConfigRemoved={handleTracingConfigRemoved}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer select-none items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
|
||||
)}
|
||||
>
|
||||
<div className='ml-4 mr-1 flex items-center'>
|
||||
<Indicator color={enabled ? 'green' : 'gray'} />
|
||||
<div className='system-xs-semibold-uppercase ml-1.5 text-text-tertiary'>
|
||||
|
|
@ -243,33 +260,14 @@ const Panel: FC = () => {
|
|||
</div>
|
||||
</div>
|
||||
{InUseProviderIcon && <InUseProviderIcon className='ml-1 h-4' />}
|
||||
<Divider type='vertical' className='h-3.5' />
|
||||
<div className='flex items-center' onClick={e => e.stopPropagation()}>
|
||||
<ConfigButton
|
||||
appId={appId}
|
||||
readOnly={readOnly}
|
||||
hasConfigured
|
||||
className='ml-2'
|
||||
enabled={enabled}
|
||||
onStatusChange={handleTracingEnabledChange}
|
||||
chosenProvider={inUseTracingProvider}
|
||||
onChooseProvider={handleChooseProvider}
|
||||
arizeConfig={arizeConfig}
|
||||
phoenixConfig={phoenixConfig}
|
||||
langSmithConfig={langSmithConfig}
|
||||
langFuseConfig={langFuseConfig}
|
||||
opikConfig={opikConfig}
|
||||
weaveConfig={weaveConfig}
|
||||
aliyunConfig={aliyunConfig}
|
||||
onConfigUpdated={handleTracingConfigUpdated}
|
||||
onConfigRemoved={handleTracingConfigRemoved}
|
||||
controlShowPopup={controlShowPopup}
|
||||
/>
|
||||
<div className='ml-2 rounded-md p-1'>
|
||||
<RiEqualizer2Line className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div >
|
||||
</div >
|
||||
<Divider type='vertical' className='h-3.5' />
|
||||
</div>
|
||||
</ConfigButton>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(Panel)
|
||||
|
|
|
|||
|
|
@ -56,33 +56,50 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
|
|||
}, [isMobile, setShowTips])
|
||||
|
||||
return <div>
|
||||
{hasRelatedApps && (
|
||||
<>
|
||||
{!isMobile && (
|
||||
<Tooltip
|
||||
position='right'
|
||||
noDecoration
|
||||
popupContent={
|
||||
<LinkedAppsPanel
|
||||
relatedApps={relatedApps.data}
|
||||
isMobile={isMobile}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 text-text-secondary'>
|
||||
<span>{relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')}</span>
|
||||
<RiInformation2Line className='h-4 w-4' />
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
{/* Related apps for desktop */}
|
||||
<div className={classNames(
|
||||
'transition-all duration-200 ease-in-out',
|
||||
(hasRelatedApps && !isMobile)
|
||||
? 'w-auto opacity-100'
|
||||
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
|
||||
)}>
|
||||
<Tooltip
|
||||
position='right'
|
||||
noDecoration
|
||||
popupContent={
|
||||
<LinkedAppsPanel
|
||||
relatedApps={relatedApps?.data || []}
|
||||
isMobile={isMobile}
|
||||
/>
|
||||
}
|
||||
>
|
||||
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 whitespace-nowrap text-text-secondary'>
|
||||
<span>{relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')}</span>
|
||||
<RiInformation2Line className='h-4 w-4' />
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
{isMobile && <div className={classNames('pb-2 pt-4 text-xs font-medium uppercase text-text-tertiary', 'flex items-center justify-center gap-1 !px-0')}>
|
||||
{relatedAppsTotal || '--'}
|
||||
<PaperClipIcon className='h-4 w-4 text-text-secondary' />
|
||||
</div>}
|
||||
</>
|
||||
)}
|
||||
{!hasRelatedApps && !expand && (
|
||||
{/* Related apps for mobile */}
|
||||
<div className={classNames(
|
||||
'transition-all duration-200 ease-in-out',
|
||||
(hasRelatedApps && isMobile)
|
||||
? 'w-auto opacity-100'
|
||||
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
|
||||
)}>
|
||||
<div className={classNames('pb-2 pt-4 text-xs font-medium uppercase text-text-tertiary', 'flex items-center justify-center gap-1 whitespace-nowrap !px-0')}>
|
||||
{relatedAppsTotal || '--'}
|
||||
<PaperClipIcon className='h-4 w-4 text-text-secondary' />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* No related apps tooltip */}
|
||||
<div className={classNames(
|
||||
'transition-all duration-200 ease-in-out',
|
||||
(!hasRelatedApps && !expand)
|
||||
? 'w-auto opacity-100'
|
||||
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0',
|
||||
)}>
|
||||
<Tooltip
|
||||
position='right'
|
||||
noDecoration
|
||||
|
|
@ -103,12 +120,12 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => {
|
|||
</div>
|
||||
}
|
||||
>
|
||||
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 text-text-secondary'>
|
||||
<div className='system-xs-medium-uppercase inline-flex cursor-pointer items-center space-x-1 whitespace-nowrap text-text-secondary'>
|
||||
<span>{t('common.datasetMenus.noRelatedApp')}</span>
|
||||
<RiInformation2Line className='h-4 w-4' />
|
||||
</div>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import React from 'react'
|
||||
import DatasetUpdateForm from '@/app/components/datasets/create'
|
||||
|
||||
type Props = {}
|
||||
|
||||
const DatasetCreation = async (props: Props) => {
|
||||
const DatasetCreation = async () => {
|
||||
return (
|
||||
<DatasetUpdateForm />
|
||||
)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,10 @@ export default function CheckCode() {
|
|||
<div className='pb-4 pt-2'>
|
||||
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
|
||||
<span>
|
||||
{t('login.checkCode.tipsPrefix')}
|
||||
<strong>{email}</strong>
|
||||
</span>
|
||||
<br />
|
||||
{t('login.checkCode.validTime')}
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -93,7 +93,10 @@ export default function CheckCode() {
|
|||
<div className='pb-4 pt-2'>
|
||||
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
|
||||
<span>
|
||||
{t('login.checkCode.tipsPrefix')}
|
||||
<strong>{email}</strong>
|
||||
</span>
|
||||
<br />
|
||||
{t('login.checkCode.validTime')}
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -29,15 +29,17 @@ const DatasetInfo: FC<Props> = ({
|
|||
<div className='mr-3 shrink-0'>
|
||||
<AppIcon innerIcon={DatasetSvg} className='!border-[0.5px] !border-indigo-100 !bg-indigo-25' />
|
||||
</div>
|
||||
{expand && (
|
||||
<div className='mt-2'>
|
||||
<div className='system-md-semibold text-text-secondary'>
|
||||
{name}
|
||||
</div>
|
||||
<div className='system-2xs-medium-uppercase mt-1 text-text-tertiary'>{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}</div>
|
||||
<div className='system-xs-regular my-3 text-text-tertiary first-letter:capitalize'>{description}</div>
|
||||
<div className={`transition-all duration-200 ease-in-out ${
|
||||
expand
|
||||
? 'mt-2 w-auto opacity-100'
|
||||
: 'pointer-events-none h-0 w-0 overflow-hidden opacity-0'
|
||||
}`}>
|
||||
<div className='system-md-semibold truncate whitespace-nowrap text-text-secondary'>
|
||||
{name}
|
||||
</div>
|
||||
)}
|
||||
<div className='system-2xs-medium-uppercase mt-1 whitespace-nowrap text-text-tertiary'>{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}</div>
|
||||
<div className='system-xs-regular my-3 whitespace-nowrap text-text-tertiary first-letter:capitalize'>{description}</div>
|
||||
</div>
|
||||
{extraInfo}
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -88,7 +88,8 @@ const HeaderOptions: FC<Props> = ({
|
|||
await clearAllAnnotations(appId)
|
||||
onAdded()
|
||||
}
|
||||
catch (_) {
|
||||
catch (e) {
|
||||
console.error(`failed to clear all annotations, ${e}`)
|
||||
}
|
||||
finally {
|
||||
setShowClearConfirm(false)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import SelectTypeItem from '../select-type-item'
|
|||
import Field from './field'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscreInVarNameInput } from '@/utils/var'
|
||||
import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
|
||||
import ConfigContext from '@/context/debug-configuration'
|
||||
import type { InputVar, MoreInfo, UploadFileSetting } from '@/app/components/workflow/types'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
|
|
@ -111,7 +111,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
|||
}, [checkVariableName, tempPayload.label])
|
||||
|
||||
const handleVarNameChange = useCallback((e: ChangeEvent<any>) => {
|
||||
replaceSpaceWithUnderscreInVarNameInput(e.target)
|
||||
replaceSpaceWithUnderscoreInVarNameInput(e.target)
|
||||
const value = e.target.value
|
||||
const { isValid, errorKey, errorMessageKey } = checkKeys([value], true)
|
||||
if (!isValid) {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import React from 'react'
|
||||
import React, { useState } from 'react'
|
||||
import Link from 'next/link'
|
||||
import { RiDiscordFill, RiGithubFill } from '@remixicon/react'
|
||||
import { RiCloseLine, RiDiscordFill, RiGithubFill } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
type CustomLinkProps = {
|
||||
|
|
@ -26,9 +26,24 @@ const CustomLink = React.memo(({
|
|||
|
||||
const Footer = () => {
|
||||
const { t } = useTranslation()
|
||||
const [isVisible, setIsVisible] = useState(true)
|
||||
|
||||
const handleClose = () => {
|
||||
setIsVisible(false)
|
||||
}
|
||||
|
||||
if (!isVisible)
|
||||
return null
|
||||
|
||||
return (
|
||||
<footer className='shrink-0 grow-0 px-12 py-6'>
|
||||
<footer className='relative shrink-0 grow-0 px-12 py-2'>
|
||||
<button
|
||||
onClick={handleClose}
|
||||
className='absolute right-2 top-2 flex h-6 w-6 cursor-pointer items-center justify-center rounded-full transition-colors duration-200 ease-in-out hover:bg-components-main-nav-nav-button-bg-active'
|
||||
aria-label="Close footer"
|
||||
>
|
||||
<RiCloseLine className='h-4 w-4 text-text-tertiary hover:text-text-secondary' />
|
||||
</button>
|
||||
<h3 className='text-gradient text-xl font-semibold leading-tight'>{t('app.join')}</h3>
|
||||
<p className='system-sm-regular mt-1 text-text-tertiary'>{t('app.communityIntro')}</p>
|
||||
<div className='mt-3 flex items-center gap-2'>
|
||||
|
|
|
|||
|
|
@ -115,8 +115,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (appData?.site.default_language)
|
||||
changeLanguage(appData.site.default_language)
|
||||
const setLocaleFromProps = async () => {
|
||||
if (appData?.site.default_language)
|
||||
await changeLanguage(appData.site.default_language)
|
||||
}
|
||||
setLocaleFromProps()
|
||||
}, [appData])
|
||||
|
||||
const [sidebarCollapseState, setSidebarCollapseState] = useState<boolean>(false)
|
||||
|
|
@ -159,9 +162,21 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
return currentConversationId
|
||||
}, [currentConversationId, newConversationId])
|
||||
|
||||
const { data: appPinnedConversationData, mutate: mutateAppPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100))
|
||||
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100))
|
||||
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId))
|
||||
const { data: appPinnedConversationData, mutate: mutateAppPinnedConversationData } = useSWR(
|
||||
appId ? ['appConversationData', isInstalledApp, appId, true] : null,
|
||||
() => fetchConversations(isInstalledApp, appId, undefined, true, 100),
|
||||
{ revalidateOnFocus: false, revalidateOnReconnect: false },
|
||||
)
|
||||
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(
|
||||
appId ? ['appConversationData', isInstalledApp, appId, false] : null,
|
||||
() => fetchConversations(isInstalledApp, appId, undefined, false, 100),
|
||||
{ revalidateOnFocus: false, revalidateOnReconnect: false },
|
||||
)
|
||||
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(
|
||||
chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null,
|
||||
() => fetchChatList(chatShouldReloadKey, isInstalledApp, appId),
|
||||
{ revalidateOnFocus: false, revalidateOnReconnect: false },
|
||||
)
|
||||
|
||||
const [clearChatList, setClearChatList] = useState(false)
|
||||
const [isResponding, setIsResponding] = useState(false)
|
||||
|
|
|
|||
|
|
@ -101,15 +101,15 @@ export const useEmbeddedChatbot = () => {
|
|||
|
||||
if (localeParam) {
|
||||
// If locale parameter exists in URL, use it instead of default
|
||||
changeLanguage(localeParam)
|
||||
await changeLanguage(localeParam)
|
||||
}
|
||||
else if (localeFromSysVar) {
|
||||
// If locale is set as a system variable, use that
|
||||
changeLanguage(localeFromSysVar)
|
||||
await changeLanguage(localeFromSysVar)
|
||||
}
|
||||
else if (appInfo?.site.default_language) {
|
||||
// Otherwise use the default from app config
|
||||
changeLanguage(appInfo.site.default_language)
|
||||
await changeLanguage(appInfo.site.default_language)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import {
|
|||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { RiExternalLinkLine } from '@remixicon/react'
|
||||
import type { AnyFieldApi } from '@tanstack/react-form'
|
||||
import { useStore } from '@tanstack/react-form'
|
||||
import cn from '@/utils/classnames'
|
||||
|
|
@ -200,6 +201,22 @@ const BaseField = ({
|
|||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
formSchema.url && (
|
||||
<a
|
||||
className='system-xs-regular mt-4 flex items-center text-text-accent'
|
||||
href={formSchema?.url}
|
||||
target='_blank'
|
||||
>
|
||||
<span className='break-all'>
|
||||
{renderI18nObject(formSchema?.help as any)}
|
||||
</span>
|
||||
{
|
||||
<RiExternalLinkLine className='ml-1 h-3 w-3' />
|
||||
}
|
||||
</a>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,9 +14,26 @@ export type Attrs = {
|
|||
|
||||
export function normalizeAttrs(attrs: Attrs = {}): Attrs {
|
||||
return Object.keys(attrs).reduce((acc: Attrs, key) => {
|
||||
// Filter out editor metadata attributes before processing
|
||||
if (key.startsWith('inkscape:')
|
||||
|| key.startsWith('sodipodi:')
|
||||
|| key.startsWith('xmlns:inkscape')
|
||||
|| key.startsWith('xmlns:sodipodi')
|
||||
|| key.startsWith('xmlns:svg')
|
||||
|| key === 'data-name')
|
||||
return acc
|
||||
|
||||
const val = attrs[key]
|
||||
key = key.replace(/([-]\w)/g, (g: string) => g[1].toUpperCase())
|
||||
key = key.replace(/([:]\w)/g, (g: string) => g[1].toUpperCase())
|
||||
|
||||
// Additional filter after camelCase conversion
|
||||
if (key === 'xmlnsInkscape'
|
||||
|| key === 'xmlnsSodipodi'
|
||||
|| key === 'xmlnsSvg'
|
||||
|| key === 'dataName')
|
||||
return acc
|
||||
|
||||
switch (key) {
|
||||
case 'class':
|
||||
acc.className = val
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue