mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into e-300
commit b38286f06a
@@ -476,6 +476,7 @@ LOGIN_LOCKOUT_DURATION=86400
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
@@ -54,7 +54,6 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
ext_repositories,
ext_sentry,
ext_set_secretkey,
ext_storage,
@@ -75,7 +74,6 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_repositories,
ext_celery,
ext_login,
ext_mail,
@@ -6,6 +6,7 @@ from typing import Optional
import click
from flask import current_app
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = (
Dataset.query.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
@@ -551,11 +552,12 @@ def old_metadata_migration():
page = 1
while True:
try:
documents = (
DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None)
stmt = (
select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc())
.paginate(page=page, per_page=50)
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
if not documents:
@@ -592,11 +594,15 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = DatasetMetadataBinding.query.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
).first()
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id,
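Note (illustrative, not part of this commit): the recurring pattern in the hunks above replaces the legacy Flask-SQLAlchemy Model.query.paginate() call with an explicit select() statement handed to db.paginate(). A minimal, self-contained sketch of that pattern, assuming the Dataset model and db extension referenced elsewhere in this diff:

from sqlalchemy import select
from extensions.ext_database import db
from models.dataset import Dataset

def iterate_high_quality_datasets(per_page: int = 50):
    page = 1
    while True:
        stmt = (
            select(Dataset)
            .filter(Dataset.indexing_technique == "high_quality")
            .order_by(Dataset.created_at.desc())
        )
        # db.paginate() executes the statement and returns a Pagination object.
        pagination = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=per_page, error_out=False)
        if not pagination.items:
            break
        yield from pagination.items
        if not pagination.has_next:
            break
        page += 1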
@@ -74,7 +74,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service",
default="http://sandbox:8194",
default=HttpUrl("http://sandbox:8194"),
)
CODE_EXECUTION_API_KEY: str = Field(
@@ -145,7 +145,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default="http://localhost:5002",
default=HttpUrl("http://localhost:5002"),
)
PLUGIN_DAEMON_KEY: str = Field(
@@ -188,7 +188,7 @@ class MarketplaceConfig(BaseSettings):
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default="https://marketplace.dify.ai",
default=HttpUrl("https://marketplace.dify.ai"),
)
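Note (illustrative, not part of this commit): wrapping each default in HttpUrl(...) makes the default's type match the field annotation instead of relying on pydantic to coerce a plain string. A minimal sketch with a hypothetical settings class (pydantic v2 plus pydantic-settings assumed):

from pydantic import Field, HttpUrl
from pydantic_settings import BaseSettings

class SandboxConfig(BaseSettings):
    # The default is already an HttpUrl instance, matching the annotation.
    CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
        description="URL endpoint for the code execution service",
        default=HttpUrl("http://sandbox:8194"),
    )

print(SandboxConfig().CODE_EXECUTION_ENDPOINT)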
@@ -1,6 +1,6 @@
import os
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@@ -173,17 +173,31 @@ class DatabaseConfig(BaseSettings):
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count(),
default=os.cpu_count() or 1,
)
@computed_field
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
"connect_args": connect_args,
}
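Note (illustrative, not part of this commit): the new SQLALCHEMY_ENGINE_OPTIONS merges any "options" value supplied through DB_EXTRAS with the always-applied UTC timezone flag. A minimal sketch of just the merge step, assuming DB_EXTRAS is a query-string-style value:

from urllib.parse import parse_qsl

def merge_pg_options(db_extras: str) -> str:
    # DB_EXTRAS is parsed like a query string, e.g. "options=-c search_path=myschema".
    db_extras_dict = dict(parse_qsl(db_extras))
    options = db_extras_dict.get("options", "")
    timezone_opt = "-c timezone=UTC"
    return f"{options} {timezone_opt}" if options else timezone_opt

print(merge_pg_options(""))                                  # -c timezone=UTC
print(merge_pg_options("options=-c search_path=myschema"))   # -c search_path=myschema -c timezone=UTC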
@@ -83,3 +83,13 @@ class RedisConfig(BaseSettings):
description="Password for Redis Clusters authentication (if required)",
default=None,
)
REDIS_SERIALIZATION_PROTOCOL: int = Field(
description="Redis serialization protocol (RESP) version",
default=3,
)
REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field(
description="Enable client side cache in redis",
default=False,
)
@@ -27,6 +27,11 @@ class OTelConfig(BaseSettings):
default="otlp",
)
OTEL_EXPORTER_OTLP_PROTOCOL: str = Field(
description="OTLP exporter protocol ('grpc' or 'http')",
default="http",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
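Note (illustrative, not part of this commit): settings like the OTLP endpoint, sampling rate, and batch delay above would typically feed an OpenTelemetry tracer provider roughly as sketched below; the project's actual wiring lives in its ext_otel extension, which is not shown in this diff, so the function here is only an assumed shape.

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased

def setup_tracing(otlp_base_endpoint: str, api_key: str, sampling_rate: float, schedule_delay_ms: int) -> None:
    # Sample a fraction of traces (OTEL_SAMPLING_RATE).
    provider = TracerProvider(sampler=TraceIdRatioBased(sampling_rate))
    exporter = OTLPSpanExporter(
        endpoint=f"{otlp_base_endpoint}/v1/traces",
        headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
    )
    # Export spans in batches (OTEL_BATCH_EXPORT_SCHEDULE_DELAY, in milliseconds).
    provider.add_span_processor(BatchSpanProcessor(exporter, schedule_delay_millis=schedule_delay_ms))
    trace.set_tracer_provider(provider)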
@@ -1,5 +1,7 @@
from flask_restful import fields
from libs.helper import AppIconUrlField
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
@@ -22,3 +24,20 @@ parameters_fields = {
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
@@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))
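Note (illustrative, not part of this commit): the hunk above shows the other recurring pattern in this commit, moving count queries off the legacy DocumentSegment.query property onto an explicit db.session.query(...). A minimal standalone sketch of the same counting logic, assuming the DocumentSegment model used above:

from extensions.ext_database import db
from models.dataset import DocumentSegment

def segment_progress(document_id: str) -> tuple[int, int]:
    # Segments that finished indexing for this document.
    completed = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.status != "re_segment",
        )
        .count()
    )
    # All segments that count toward the total.
    total = (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.status != "re_segment",
        )
        .count()
    )
    return completed, total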
@@ -6,7 +6,7 @@ from typing import cast
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
limits = DocumentService.DEFAULT_RULES["limits"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
document = db.get_or_404(Document, document_id)
dataset = DatasetService.get_dataset(document.dataset_id)
@@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
search = f"%{search}%"
@@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource):
desc(Document.position),
)
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
@@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
@@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
@@ -4,6 +4,7 @@ import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -26,6 +27,7 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
@@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).order_by(DocumentSegment.position.asc())
query = (
select(DocumentSegment)
.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
response = {
"data": marshal(segments.items, segment_fields),
@@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
@@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
@@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)
return response
@@ -87,7 +87,7 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")
@@ -100,9 +100,11 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = InstalledApp.query.filter(
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
).first()
installed_app = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)
if installed_app is None:
# todo: position
@@ -79,7 +79,6 @@ class MemberInviteEmailApi(Resource):
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
break
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
@@ -3,6 +3,7 @@ import logging
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
@@ -88,9 +89,8 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
page=args["page"], per_page=args["limit"], error_out=False
)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
has_more = False
if tenants.has_next:
@@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource):
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"]
db.session.commit()
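Note (illustrative, not part of this commit): the one_or_404() and get_or_404() lookups above are replaced with Flask-SQLAlchemy's session-level db.get_or_404(), which does a primary-key fetch and aborts with a 404 when nothing is found. A minimal sketch, assuming Tenant is importable from models.account as in the rest of the codebase:

from extensions.ext_database import db
from models.account import Tenant

def load_tenant_or_404(tenant_id: str) -> Tenant:
    # Equivalent to the removed Tenant.query.filter(...).one_or_404() lookup,
    # but resolved by primary key; raises a 404 HTTPException if missing.
    return db.get_or_404(Tenant, tenant_id)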
@@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models
@@ -93,6 +93,18 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
class AppGetFeedbacksApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get All Feedbacks of an app"""
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args")
args = parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
return {"data": feedbacks}
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
@@ -119,3 +131,4 @@ class MessageSuggestedApi(Resource):
api.add_resource(MessageListApi, "/messages")
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
api.add_resource(AppGetFeedbacksApi, "/app/feedbacks")
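Note (illustrative, not part of this commit): the new AppGetFeedbacksApi is registered on the service API blueprint (url_prefix /v1), so a client could page through feedbacks roughly as sketched below; the base URL and app token are placeholders, not values from this diff.

import requests

API_BASE = "https://api.example.com/v1"   # placeholder for the service API base URL
APP_API_KEY = "app-xxxxxxxx"              # placeholder app token accepted by validate_app_token

def fetch_feedbacks(page: int = 1, limit: int = 20) -> list[dict]:
    resp = requests.get(
        f"{API_BASE}/app/feedbacks",
        headers={"Authorization": f"Bearer {APP_API_KEY}"},
        params={"page": page, "limit": limit},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()["data"]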
@@ -0,0 +1,30 @@
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
from controllers.common import fields
from controllers.service_api import api
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
class AppSiteApi(Resource):
"""Resource for app sites."""
@validate_app_token
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
api.add_resource(AppSiteApi, "/site")
@@ -313,7 +313,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@@ -2,10 +2,10 @@ import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc
from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound
import services.dataset_service
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
@@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return 204
class DocumentListApi(DatasetApiResource):
@@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f"%{search}%"
@@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
@@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
raise NotFound("Documents not found.")
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
@@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 204
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
@@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 204
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@@ -1,11 +1,11 @@
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from controllers.common import fields
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import \
get_parameters_from_feature_dict
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from models.model import App, AppMode
from services.app_service import AppService
@@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=cast(dict, application_generate_entity.inputs),
)
# get how many agent thoughts have been created
self.agent_thought_count = (
@@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
@@ -87,6 +80,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
@@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
assert app_config.agent
iteration_step = 1
@@ -72,6 +65,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
@@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
@@ -163,12 +163,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@@ -231,12 +229,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@@ -297,12 +293,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
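Note (illustrative, not part of this commit): the generators above now construct SQLAlchemyWorkflowNodeExecutionRepository directly instead of going through RepositoryFactory with a params dict, so the dependencies are explicit keyword arguments. A condensed sketch of that pattern, using only names that appear in this diff:

from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

def build_node_execution_repository(tenant_id: str, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository:
    # expire_on_commit=False keeps returned ORM objects usable after the session commits.
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_factory,
        tenant_id=tenant_id,
        app_id=app_id,
    )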
@@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@@ -58,7 +57,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
@@ -66,6 +65,7 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
@@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
@@ -18,13 +18,13 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
@@ -138,12 +138,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@@ -264,12 +262,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@@ -329,12 +325,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
@@ -45,6 +44,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -0,0 +1 @@
# Core base package
@@ -0,0 +1,6 @@
from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
__all__ = [
"AppGeneratorTTSPublisher",
"AudioTrunk",
]
@@ -1,3 +1,5 @@
import logging
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -7,6 +9,8 @@ from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
_logger = logging.getLogger(__name__)
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
@@ -42,18 +46,31 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end."""
for document in documents:
if document.metadata is not None:
dataset_document = DatasetDocument.query.filter(
DatasetDocument.id == document.metadata["document_id"]
).first()
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(
@@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@@ -103,15 +103,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
@@ -162,15 +164,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is indexing."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = []
if document_segments:
@@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@@ -587,7 +591,7 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
@@ -656,10 +660,10 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first()
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedError()
@@ -668,7 +672,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.commit()
@staticmethod
@@ -676,7 +680,7 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(
@@ -1,9 +1,9 @@
from enum import Enum
from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
class TracingProviderEnum(StrEnum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
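Note (illustrative, not part of this commit): switching TracingProviderEnum from Enum to StrEnum makes each member compare and format as its plain string value, so provider strings can be matched without reaching for .value. A minimal sketch with hypothetical class names:

from enum import Enum, StrEnum

class OldProvider(Enum):
    LANGFUSE = "langfuse"

class NewProvider(StrEnum):
    LANGFUSE = "langfuse"

print(OldProvider.LANGFUSE == "langfuse")  # False: plain Enum members are not strings
print(NewProvider.LANGFUSE == "langfuse")  # True: StrEnum members are str subclasses
print(f"{NewProvider.LANGFUSE}")           # "langfuse" rather than "NewProvider.LANGFUSE"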
@@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser
@@ -113,8 +113,8 @@ class LangFuseDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id
)
# Get all executions for this workflow run
@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -137,12 +137,8 @@ class LangSmithDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -150,12 +150,8 @@ class OpikDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run
@@ -16,11 +16,7 @@ from sqlalchemy.orm import Session
|
|||
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
LangfuseConfig,
|
||||
LangSmithConfig,
|
||||
OpikConfig,
|
||||
TracingProviderEnum,
|
||||
WeaveConfig,
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
|
|
@@ -33,11 +29,7 @@ from core.ops.entities.trace_entity import (
|
|||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
||||
from core.ops.utils import get_message_data
|
||||
from core.ops.weave_trace.weave_trace import WeaveDataTrace
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
|
|
@@ -45,36 +37,58 @@ from models.workflow import WorkflowAppLog, WorkflowRun
|
|||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
|
||||
def build_opik_trace_instance(config: OpikConfig):
|
||||
return OpikDataTrace(config)
|
||||
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.LANGSMITH:
|
||||
from core.ops.entities.config_entity import LangSmithConfig
|
||||
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
|
||||
|
||||
return {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.OPIK:
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace
|
||||
|
||||
return {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": OpikDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.WEAVE:
|
||||
from core.ops.entities.config_entity import WeaveConfig
|
||||
from core.ops.weave_trace.weave_trace import WeaveDataTrace
|
||||
|
||||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
|
||||
|
||||
provider_config_map: dict[str, dict[str, Any]] = {
|
||||
TracingProviderEnum.LANGFUSE.value: {
|
||||
"config_class": LangfuseConfig,
|
||||
"secret_keys": ["public_key", "secret_key"],
|
||||
"other_keys": ["host", "project_key"],
|
||||
"trace_instance": LangFuseDataTrace,
|
||||
},
|
||||
TracingProviderEnum.LANGSMITH.value: {
|
||||
"config_class": LangSmithConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "endpoint"],
|
||||
"trace_instance": LangSmithDataTrace,
|
||||
},
|
||||
TracingProviderEnum.OPIK.value: {
|
||||
"config_class": OpikConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "url", "workspace"],
|
||||
"trace_instance": lambda config: build_opik_trace_instance(config),
|
||||
},
|
||||
TracingProviderEnum.WEAVE.value: {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
},
|
||||
}
|
||||
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
|
||||
|
||||
|
||||
class OpsTraceManager:
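A short usage sketch of the new OpsTraceProviderConfigMap (names taken from this hunk; the surrounding manager code is assumed). Lookups import the provider-specific classes lazily, and unknown providers raise KeyError:

provider_config_map = OpsTraceProviderConfigMap()

entry = provider_config_map[TracingProviderEnum.LANGFUSE]
config_class = entry["config_class"]    # LangfuseConfig
trace_class = entry["trace_instance"]   # LangFuseDataTrace
secret_keys = entry["secret_keys"]      # ["public_key", "secret_key"]

# provider_config_map["something-else"] raises
# KeyError("Unsupported tracing provider: something-else")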
|
||||
|
|
|
|||
|
|
@@ -24,7 +24,7 @@ class EndpointProviderDeclaration(BaseModel):
    """

    settings: list[ProviderConfig] = Field(default_factory=list)
    endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list)
    endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration])


class EndpointEntity(BasePluginEntity):
|
||||
|
|
|
|||
|
|
@@ -52,7 +52,7 @@ class PluginResourceRequirements(BaseModel):
|
|||
model: Optional[Model] = Field(default=None)
|
||||
node: Optional[Node] = Field(default=None)
|
||||
endpoint: Optional[Endpoint] = Field(default=None)
|
||||
storage: Storage = Field(default=None)
|
||||
storage: Optional[Storage] = Field(default=None)
|
||||
|
||||
permission: Optional[Permission] = Field(default=None)
|
||||
|
||||
|
|
@@ -66,9 +66,9 @@ class PluginCategory(enum.StrEnum):
|
|||
|
||||
class PluginDeclaration(BaseModel):
|
||||
class Plugins(BaseModel):
|
||||
tools: Optional[list[str]] = Field(default_factory=list)
|
||||
models: Optional[list[str]] = Field(default_factory=list)
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list)
|
||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||
|
|
@@ -84,6 +84,7 @@ class PluginDeclaration(BaseModel):
|
|||
resource: PluginResourceRequirements
|
||||
plugins: Plugins
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
repo: Optional[str] = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: Optional[ToolProviderEntity] = None
|
||||
model: Optional[ProviderEntity] = None
|
||||
|
|
|
|||
|
|
@@ -55,8 +55,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
|
|||
mode: str
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt_messages: list[PromptMessage] = Field(default_factory=list)
|
||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
|
||||
stop: Optional[list[str]] = Field(default_factory=list)
|
||||
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool])
|
||||
stop: Optional[list[str]] = Field(default_factory=list[str])
|
||||
stream: Optional[bool] = False
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
|
|
|||
|
|
@@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
|||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
|
|
@@ -119,12 +120,25 @@ class RetrievalService:
|
|||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
||||
def external_retrieve(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: Optional[dict] = None,
|
||||
metadata_filtering_conditions: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model or {},
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
return all_documents
|
||||
|
||||
|
|
|
|||
|
|
@@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
|
|||
data_source_info["last_edited_time"] = last_edited_time
|
||||
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
|
||||
|
||||
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
||||
db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
|
||||
db.session.commit()
|
||||
|
||||
def get_notion_last_edited_time(self) -> str:
|
||||
|
|
@@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):
|
|||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
|
||||
)
|
||||
)
|
||||
).first()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not data_source_binding:
|
||||
raise Exception(
|
||||
|
|
|
|||
|
|
@@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor):
|
|||
parsed = urlparse(url)
|
||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
||||
|
||||
def _extract_images_from_docx(self, doc, image_folder):
|
||||
os.makedirs(image_folder, exist_ok=True)
|
||||
def _extract_images_from_docx(self, doc):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
|
||||
|
|
@@ -210,7 +209,7 @@ class WordExtractor(BaseExtractor):
|
|||
|
||||
content = []
|
||||
|
||||
image_map = self._extract_images_from_docx(doc, image_folder)
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
hyperlinks_url = None
|
||||
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
|
||||
|
|
@@ -225,7 +224,7 @@ class WordExtractor(BaseExtractor):
|
|||
xml = ElementTree.XML(run.element.xml)
|
||||
x_child = [c for c in xml.iter() if c is not None]
|
||||
for x in x_child:
|
||||
if x_child is None:
|
||||
if x is None:
|
||||
continue
|
||||
if x.tag.endswith("instrText"):
|
||||
if x.text is None:
|
||||
|
|
|
|||
|
|
@@ -149,7 +149,7 @@ class DatasetRetrieval:
|
|||
else:
|
||||
inputs = {}
|
||||
available_datasets_ids = [dataset.id for dataset in available_datasets]
|
||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
||||
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
|
||||
available_datasets_ids,
|
||||
query,
|
||||
tenant_id,
|
||||
|
|
@@ -237,12 +237,16 @@ class DatasetRetrieval:
|
|||
if show_retrieve_source:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
|
|
@@ -506,19 +510,30 @@ class DatasetRetrieval:
|
|||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
).first()
|
||||
dataset_document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.filter(DatasetDocument.id == document.metadata["document_id"])
|
||||
.first()
|
||||
)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = ChildChunk.query.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
).first()
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if child_chunk:
|
||||
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
|
|
@@ -649,6 +664,8 @@ class DatasetRetrieval:
|
|||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
|
|
@@ -706,6 +723,9 @@ class DatasetRetrieval:
|
|||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
retrieve_config=retrieve_config,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
|
|
@@ -826,7 +846,7 @@ class DatasetRetrieval:
|
|||
)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
def get_metadata_filter_condition(
|
||||
self,
|
||||
dataset_ids: list,
|
||||
query: str,
|
||||
|
|
@@ -876,20 +896,31 @@ class DatasetRetrieval:
|
|||
)
|
||||
elif metadata_filtering_mode == "manual":
|
||||
if metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
|
|
|
|||
|
|
@@ -4,3 +4,9 @@ Repository implementations for data access.
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""

from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository

__all__ = [
    "SQLAlchemyWorkflowNodeExecutionRepository",
]
|
||||
|
|
|
|||
|
|
@@ -1,87 +0,0 @@
|
|||
"""
|
||||
Registry for repository implementations.
|
||||
|
||||
This module is responsible for registering factory functions with the repository factory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage type constants
|
||||
STORAGE_TYPE_RDBMS = "rdbms"
|
||||
STORAGE_TYPE_HYBRID = "hybrid"
|
||||
|
||||
|
||||
def register_repositories() -> None:
|
||||
"""
|
||||
Register repository factory functions with the RepositoryFactory.
|
||||
|
||||
This function reads configuration settings to determine which repository
|
||||
implementations to register.
|
||||
"""
|
||||
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||
|
||||
# Check storage type and register appropriate implementation
|
||||
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||
# Register SQLAlchemy implementation for RDBMS storage
|
||||
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||
# Hybrid storage is not yet implemented
|
||||
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||
else:
|
||||
# Unknown storage type
|
||||
raise ValueError(
|
||||
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||
)
|
||||
|
||||
|
||||
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||
|
||||
This factory function creates a repository for the RDBMS storage type.
|
||||
|
||||
Args:
|
||||
params: Parameters for creating the repository, including:
|
||||
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||
- app_id: Optional. The application ID for filtering.
|
||||
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||
a new sessionmaker will be created using the global database engine.
|
||||
|
||||
Returns:
|
||||
A WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing
|
||||
"""
|
||||
# Extract required parameters
|
||||
tenant_id = params.get("tenant_id")
|
||||
if tenant_id is None:
|
||||
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||
|
||||
# Extract optional parameters
|
||||
app_id = params.get("app_id")
|
||||
|
||||
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||
session_factory = params.get("session_factory")
|
||||
if session_factory is None:
|
||||
# Create a sessionmaker using the same engine as the global db session
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Create and return the repository
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
|
|
@@ -10,13 +10,13 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
|
|
@@ -1,9 +0,0 @@
|
|||
"""
|
||||
WorkflowNodeExecution repository implementations.
|
||||
"""
|
||||
|
||||
from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
|
@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
|
|
@@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
|
|
|
|||
|
|
@@ -1,11 +1,12 @@
|
|||
from typing import Any
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
|
|
@@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
metadata_filtering_conditions: MetadataCondition
|
||||
user_id: Optional[str] = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
|
|
@@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
metadata_filtering_conditions=MetadataCondition(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
|
@@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
[dataset.id],
|
||||
query,
|
||||
self.tenant_id,
|
||||
self.user_id or "unknown",
|
||||
cast(str, self.retrieve_config.metadata_filtering_mode),
|
||||
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
||||
self.retrieve_config.metadata_filtering_conditions,
|
||||
self.inputs,
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
|
|
@@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=self.metadata_filtering_conditions,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
|
|
@@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
|
|
@@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
|
@@ -161,12 +185,16 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument) # type: ignore
|
||||
.filter(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"dataset_id": dataset.id,
|
||||
|
|
|
|||
|
|
@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
|
|||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
|
|
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
|
|||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
|
|
|||
|
|
@@ -30,7 +30,7 @@ class Variable(Segment):
    """

    id: str = Field(
        default=lambda _: str(uuid4()),
        default_factory=lambda: str(uuid4()),
        description="Unique identity for variable.",
    )
    name: str
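Sketch of why this change (and the related Graph and LoopNodeData fixes further down) matters: Field(default=...) stores the given object itself as the default, so default=lambda _: str(uuid4()) would make the default the lambda and default=list the list type, whereas default_factory is called once per instance. Illustrative only, using a stand-in model:

from uuid import uuid4
from pydantic import BaseModel, Field

class Variable(BaseModel):
    # default_factory runs for every new instance, so each Variable gets a fresh UUID string
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str

a, b = Variable(name="a"), Variable(name="b")
assert a.id != b.id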
|
||||
|
|
|
|||
|
|
@@ -36,7 +36,7 @@ class Graph(BaseModel):
    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=list, description="node configs mapping (node id: node config)"
        default_factory=dict, description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="graph edge mapping (source node id: edges)"
|
||||
|
|
|
|||
|
|
@@ -95,7 +95,12 @@ class StreamProcessor(ABC):
|
|||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
if node_id in reachable_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
|
||||
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
continue
|
||||
|
|
|
|||
|
|
@@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]):
        depth: int = 1,
    ):
        if depth > dify_config.CODE_MAX_DEPTH:
            raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
            raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

        transformed_result: dict[str, Any] = {}
        if output_schema is None:
|
||||
|
|
|
|||
|
|
@@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
return event
|
||||
|
||||
iter_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
|
||||
}
|
||||
if parallel_mode_run_id:
|
||||
# for parallel, the specific branch ID is more important than the sequential index
|
||||
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.ITERATION_ID: self.node_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
|
||||
if self.node_data.is_parallel
|
||||
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
|
||||
if self.node_data.is_parallel
|
||||
else iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
current_metadata = event.route_node_state.node_run_result.metadata or {}
|
||||
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
|
||||
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
|
||||
|
||||
return event
|
||||
|
||||
def _run_single_iter(
|
||||
|
|
|
|||
|
|
@@ -264,6 +264,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
"data_source_type": "external",
|
||||
"retriever_from": "workflow",
|
||||
"score": item.metadata.get("score"),
|
||||
"doc_metadata": item.metadata,
|
||||
},
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
|
|
@@ -275,12 +276,16 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
|
|
@@ -289,7 +294,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"document_data_source_type": document.data_source_type,
|
||||
"data_source_type": document.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": "workflow",
|
||||
"score": record.score or 0.0,
|
||||
|
|
@@ -356,12 +361,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
||||
conditions = []
|
||||
if node_data.metadata_filtering_conditions:
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
|
|
@@ -372,13 +377,24 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
|
|
|
|||
|
|
@@ -506,7 +506,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
"dataset_name": metadata.get("dataset_name"),
|
||||
"document_id": metadata.get("document_id"),
|
||||
"document_name": metadata.get("document_name"),
|
||||
"data_source_type": metadata.get("document_data_source_type"),
|
||||
"data_source_type": metadata.get("data_source_type"),
|
||||
"segment_id": metadata.get("segment_id"),
|
||||
"retriever_from": metadata.get("retriever_from"),
|
||||
"score": metadata.get("score"),
|
||||
|
|
|
|||
|
|
@@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData):
    loop_count: int  # Maximum number of loops
    break_conditions: list[Condition]  # Conditions to break the loop
    logical_operator: Literal["and", "or"]
    loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
    loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
    outputs: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -337,7 +337,7 @@ class LoopNode(BaseNode[LoopNodeData]):
|
|||
return {"check_break_result": True}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
|
|
|
|||
|
|
@@ -6,10 +6,9 @@ for accessing and manipulating data, regardless of the underlying
|
|||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.workflow.repository.repository_factory import RepositoryFactory
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"RepositoryFactory",
|
||||
"OrderConfig",
|
||||
"WorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
|
|
|||
|
|
@@ -1,97 +0,0 @@
|
|||
"""
|
||||
Repository factory for creating repository instances.
|
||||
|
||||
This module provides a simple factory interface for creating repository instances.
|
||||
It does not contain any implementation details or dependencies on specific repositories.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||
|
||||
# Type for workflow node execution factory function
|
||||
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||
|
||||
# Repository type literals
|
||||
_RepositoryType = Literal["workflow_node_execution"]
|
||||
|
||||
|
||||
class RepositoryFactory:
|
||||
"""
|
||||
Factory class for creating repository instances.
|
||||
|
||||
This factory delegates the actual repository creation to implementation-specific
|
||||
factory functions that are registered with the factory at runtime.
|
||||
"""
|
||||
|
||||
# Dictionary to store factory functions
|
||||
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||
|
||||
@classmethod
|
||||
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for a specific repository type.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||
factory_func: A function that takes parameters and returns a repository instance
|
||||
"""
|
||||
cls._factory_functions[repository_type] = factory_func
|
||||
|
||||
@classmethod
|
||||
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Create a new repository instance with the provided parameters.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository to create
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the requested repository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the repository type
|
||||
"""
|
||||
if repository_type not in cls._factory_functions:
|
||||
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||
|
||||
# Use empty dict if params is None
|
||||
params = params or {}
|
||||
|
||||
return cls._factory_functions[repository_type](params)
|
||||
|
||||
@classmethod
|
||||
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for the workflow node execution repository.
|
||||
|
||||
Args:
|
||||
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||
"""
|
||||
cls._register_factory("workflow_node_execution", factory_func)
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls, params: Optional[Mapping[str, Any]] = None
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the WorkflowNodeExecutionRepository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||
"""
|
||||
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
||||
|
|
@@ -39,7 +39,7 @@ class SubCondition(BaseModel):

class SubVariableCondition(BaseModel):
    logical_operator: Literal["and", "or"]
    conditions: list[SubCondition] = Field(default=list)
    conditions: list[SubCondition] = Field(default_factory=list)


class Condition(BaseModel):
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,6 @@ from typing import Optional, Union
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
|
|
@@ -52,10 +51,11 @@ from core.app.entities.task_entities import (
|
|||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
|
|
@@ -102,7 +102,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
self._workflow_cycle_manager = WorkflowCycleManager(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
|
|
@@ -69,7 +69,7 @@ from models.workflow import (
|
|||
)
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
class WorkflowCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@@ -39,6 +39,10 @@ def init_app(app: DifyApp):
|
|||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Apply RequestIdFormatter to all handlers
|
||||
apply_request_id_formatter()
|
||||
|
||||
# Disable propagation for noisy loggers to avoid duplicate logs
|
||||
logging.getLogger("sqlalchemy.engine").propagate = False
|
||||
log_tz = dify_config.LOG_TZ
|
||||
|
|
@@ -74,3 +78,16 @@ class RequestIdFilter(logging.Filter):
|
|||
def filter(self, record):
|
||||
record.req_id = get_request_id() if flask.has_request_context() else ""
|
||||
return True
|
||||
|
||||
|
||||
class RequestIdFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
if not hasattr(record, "req_id"):
|
||||
record.req_id = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def apply_request_id_formatter():
|
||||
for handler in logging.root.handlers:
|
||||
if handler.formatter:
|
||||
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
|
||||
|
|
|
|||
|
|
@@ -26,7 +26,7 @@ class Mail:
|
|||
|
||||
match mail_type:
|
||||
case "resend":
|
||||
import resend # type: ignore
|
||||
import resend
|
||||
|
||||
api_key = dify_config.RESEND_API_KEY
|
||||
if not api_key:
|
||||
|
|
|
|||
|
|
@@ -6,6 +6,7 @@ import socket
|
|||
import sys
|
||||
from typing import Union
|
||||
|
||||
import flask
|
||||
from celery.signals import worker_init # type: ignore
|
||||
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
|
||||
|
||||
|
|
@@ -27,6 +28,8 @@ def on_user_loaded(_sender, user):
|
|||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
from opentelemetry.semconv.trace import SpanAttributes
|
||||
|
||||
def is_celery_worker():
|
||||
return "celery" in sys.argv[0].lower()
|
||||
|
||||
|
|
@@ -37,7 +40,9 @@ def init_app(app: DifyApp):
|
|||
def init_flask_instrumentor(app: DifyApp):
|
||||
meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
|
||||
_http_response_counter = meter.create_counter(
|
||||
"http.server.response.count", description="Total number of HTTP responses by status code", unit="{response}"
|
||||
"http.server.response.count",
|
||||
description="Total number of HTTP responses by status code, method and target",
|
||||
unit="{response}",
|
||||
)
|
||||
|
||||
def response_hook(span: Span, status: str, response_headers: list):
|
||||
|
|
@@ -50,7 +55,13 @@ def init_app(app: DifyApp):
|
|||
status = status.split(" ")[0]
|
||||
status_code = int(status)
|
||||
status_class = f"{status_code // 100}xx"
|
||||
_http_response_counter.add(1, {"status_code": status_code, "status_class": status_class})
|
||||
attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class}
|
||||
request = flask.request
|
||||
if request and request.url_rule:
|
||||
attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule)
|
||||
if request and request.method:
|
||||
attributes[SpanAttributes.HTTP_METHOD] = str(request.method)
|
||||
_http_response_counter.add(1, attributes)
|
||||
|
||||
instrumentor = FlaskInstrumentor()
|
||||
if dify_config.DEBUG:
|
||||
|
|
@@ -103,8 +114,10 @@ def init_app(app: DifyApp):
|
|||
pass
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
||||
|
|
@@ -147,19 +160,32 @@ def init_app(app: DifyApp):
|
|||
sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE)
|
||||
provider = TracerProvider(resource=resource, sampler=sampler)
|
||||
set_tracer_provider(provider)
|
||||
exporter: Union[OTLPSpanExporter, ConsoleSpanExporter]
|
||||
metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter]
|
||||
exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter]
|
||||
metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter]
|
||||
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
|
||||
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
|
||||
exporter = OTLPSpanExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
|
||||
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
|
||||
)
|
||||
metric_exporter = OTLPMetricExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
|
||||
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
|
||||
)
|
||||
if protocol == "grpc":
|
||||
exporter = GRPCSpanExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
# Header field names must consist of lowercase letters, check RFC7540
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
)
|
||||
metric_exporter = GRPCMetricExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT,
|
||||
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
|
||||
insecure=True,
|
||||
)
|
||||
else:
|
||||
exporter = HTTPSpanExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
|
||||
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
|
||||
)
|
||||
metric_exporter = HTTPMetricExporter(
|
||||
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
|
||||
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
|
||||
)
|
||||
else:
|
||||
# Fallback to console exporter
|
||||
exporter = ConsoleSpanExporter()
|
||||
metric_exporter = ConsoleMetricExporter()
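Condensed view of the exporter selection added here (settings and endpoints as shown in this hunk): OTEL_EXPORTER_OTLP_PROTOCOL chooses between the gRPC and HTTP OTLP exporters, and the console exporters remain the fallback when OTEL_EXPORTER_TYPE is not "otlp".

protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()

if dify_config.OTEL_EXPORTER_TYPE == "otlp":
    if protocol == "grpc":
        # gRPC: bare endpoint, lowercase header tuples (RFC 7540), no per-signal path
        exporter = GRPCSpanExporter(
            endpoint=dify_config.OTLP_BASE_ENDPOINT,
            headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
            insecure=True,
        )
    else:
        # HTTP (default): signal-specific path is appended to the base endpoint
        exporter = HTTPSpanExporter(
            endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
            headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
        )
else:
    exporter = ConsoleSpanExporter()  # console fallback, as above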
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
|
|||
from typing import Any, Union
|
||||
|
||||
import redis
|
||||
from redis.cache import CacheConfig
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.sentinel import Sentinel
|
||||
|
|
@@ -51,6 +52,14 @@ def init_app(app: DifyApp):
|
|||
connection_class: type[Union[Connection, SSLConnection]] = Connection
|
||||
if dify_config.REDIS_USE_SSL:
|
||||
connection_class = SSLConnection
|
||||
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
|
||||
if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
|
||||
if resp_protocol >= 3:
|
||||
clientside_cache_config = CacheConfig()
|
||||
else:
|
||||
raise ValueError("Client side cache is only supported in RESP3")
|
||||
else:
|
||||
clientside_cache_config = None
|
||||
|
||||
redis_params: dict[str, Any] = {
|
||||
"username": dify_config.REDIS_USERNAME,
|
||||
|
|
@@ -59,6 +68,8 @@ def init_app(app: DifyApp):
|
|||
"encoding": "utf-8",
|
||||
"encoding_errors": "strict",
|
||||
"decode_responses": False,
|
||||
"protocol": resp_protocol,
|
||||
"cache_config": clientside_cache_config,
|
||||
}
|
||||
|
||||
if dify_config.REDIS_USE_SENTINEL:
|
||||
|
|
@@ -82,14 +93,22 @@ def init_app(app: DifyApp):
|
|||
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
|
||||
for node in dify_config.REDIS_CLUSTERS.split(",")
|
||||
]
|
||||
# FIXME: mypy error here, try to figure out how to fix it
|
||||
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore
|
||||
redis_client.initialize(
|
||||
RedisCluster(
|
||||
startup_nodes=nodes,
|
||||
password=dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||
protocol=resp_protocol,
|
||||
cache_config=clientside_cache_config,
|
||||
)
|
||||
)
|
||||
else:
|
||||
redis_params.update(
|
||||
{
|
||||
"host": dify_config.REDIS_HOST,
|
||||
"port": dify_config.REDIS_PORT,
|
||||
"connection_class": connection_class,
|
||||
"protocol": resp_protocol,
|
||||
"cache_config": clientside_cache_config,
|
||||
}
|
||||
)
|
||||
pool = redis.ConnectionPool(**redis_params)
|
||||
|
|
|
|||
|
|
@@ -1,18 +0,0 @@
|
|||
"""
|
||||
Extension for initializing repositories.
|
||||
|
||||
This extension registers repository implementations with the RepositoryFactory.
|
||||
"""
|
||||
|
||||
from core.repositories.repository_registry import register_repositories
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def init_app(_app: DifyApp) -> None:
|
||||
"""
|
||||
Initialize repository implementations.
|
||||
|
||||
Args:
|
||||
_app: The Flask application instance (unused)
|
||||
"""
|
||||
register_repositories()
|
||||
|
|
@@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource):
|
|||
"total": len(pages),
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
.first()
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.disabled = False
|
||||
|
|
@@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource):
|
|||
"total": len(pages),
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
.first()
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.disabled = False
|
||||
|
|
@@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource):
|
|||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.filter(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
)
|
||||
).first()
|
||||
.first()
|
||||
)
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
|
|
|
|||
|
|
@@ -5,45 +5,61 @@ Revises: 33f5fac87f29
|
|||
Create Date: 2024-10-10 05:16:14.764268
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from alembic import op, context
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'bbadea11becb'
|
||||
down_revision = 'd8e744d88ed6'
|
||||
revision = "bbadea11becb"
|
||||
down_revision = "d8e744d88ed6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
def _has_name_or_size_column() -> bool:
|
||||
# We cannot access the database in offline mode, so assume
|
||||
# the "name" and "size" columns do not exist.
|
||||
if context.is_offline_mode():
|
||||
# Log a warning message to inform the user that the database schema cannot be inspected
|
||||
# in offline mode, and the generated SQL may not accurately reflect the actual execution.
|
||||
op.execute(
|
||||
"-- Executing in offline mode, assuming the name and size columns do not exist.\n"
|
||||
"-- The generated SQL may differ from what will actually be executed.\n"
|
||||
"-- Please review the migration script carefully!"
|
||||
)
|
||||
|
||||
return False
|
||||
# Use SQLAlchemy inspector to get the columns of the 'tool_files' table
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col["name"] for col in inspector.get_columns("tool_files")]
|
||||
|
||||
# If 'name' or 'size' columns already exist, exit the upgrade function
|
||||
if "name" in columns or "size" in columns:
|
||||
return True
|
||||
return False
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Get the database connection
|
||||
conn = op.get_bind()
|
||||
|
||||
# Use SQLAlchemy inspector to get the columns of the 'tool_files' table
|
||||
inspector = sa.inspect(conn)
|
||||
columns = [col['name'] for col in inspector.get_columns('tool_files')]
|
||||
|
||||
# If 'name' or 'size' columns already exist, exit the upgrade function
|
||||
if 'name' in columns or 'size' in columns:
|
||||
if _has_name_or_size_column():
|
||||
return
|
||||
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('name', sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True))
|
||||
with op.batch_alter_table("tool_files", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("name", sa.String(), nullable=True))
|
||||
batch_op.add_column(sa.Column("size", sa.Integer(), nullable=True))
|
||||
op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL")
|
||||
op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL")
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.alter_column('name', existing_type=sa.String(), nullable=False)
|
||||
batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False)
|
||||
with op.batch_alter_table("tool_files", schema=None) as batch_op:
|
||||
batch_op.alter_column("name", existing_type=sa.String(), nullable=False)
|
||||
batch_op.alter_column("size", existing_type=sa.Integer(), nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_files', schema=None) as batch_op:
|
||||
batch_op.drop_column('size')
|
||||
batch_op.drop_column('name')
|
||||
with op.batch_alter_table("tool_files", schema=None) as batch_op:
|
||||
batch_op.drop_column("size")
|
||||
batch_op.drop_column("name")
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@@ -5,28 +5,38 @@ Revises: e1944c35e15e
Create Date: 2024-12-23 11:54:15.344543

"""
from alembic import op
import models as models
import sqlalchemy as sa

from alembic import op, context
from sqlalchemy import inspect


# revision identifiers, used by Alembic.
revision = 'd7999dfa4aae'
down_revision = 'e1944c35e15e'
revision = "d7999dfa4aae"
down_revision = "e1944c35e15e"
branch_labels = None
depends_on = None


def upgrade():
# Check if column exists before attempting to remove it
conn = op.get_bind()
inspector = inspect(conn)
has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')]
def _has_retry_index_column() -> bool:
if context.is_offline_mode():
# Log a warning message to inform the user that the database schema cannot be inspected
# in offline mode, and the generated SQL may not accurately reflect the actual execution.
op.execute(
'-- Executing in offline mode: assuming the "retry_index" column does not exist.\n'
"-- The generated SQL may differ from what will actually be executed.\n"
"-- Please review the migration script carefully!"
)
return False
conn = op.get_bind()
inspector = inspect(conn)
return "retry_index" in [col["name"] for col in inspector.get_columns("workflow_node_executions")]

has_column = _has_retry_index_column()

if has_column:
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
batch_op.drop_column("retry_index")


def downgrade():
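Both migrations above now wrap their schema inspection in a helper that returns early when Alembic runs in offline mode (`alembic upgrade --sql`), since `op.get_bind()` has no live connection to inspect there. A standalone sketch of that guard, where `my_table` and `my_column` are placeholders rather than names from this diff:

# Sketch of the offline-mode guard used by the two migrations above.
import sqlalchemy as sa
from alembic import context, op


def _column_exists(table: str, column: str) -> bool:
    if context.is_offline_mode():
        # No connection is available while generating SQL offline, so emit a
        # warning comment into the script and assume the column is absent.
        op.execute(f"-- Offline mode: assuming column {column} does not exist on {table}.")
        return False
    inspector = sa.inspect(op.get_bind())
    return column in [col["name"] for col in inspector.get_columns(table)]


def upgrade():
    if _column_exists("my_table", "my_column"):
        return
    with op.batch_alter_table("my_table", schema=None) as batch_op:
        batch_op.add_column(sa.Column("my_column", sa.String(), nullable=True))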
@@ -0,0 +1,33 @@
"""add index for workflow_conversation_variables.conversation_id

Revision ID: d28f2004b072
Revises: 6a9f914f656c
Create Date: 2025-05-14 14:03:36.713828

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'd28f2004b072'
down_revision = '6a9f914f656c'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'), ['conversation_id'], unique=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'))

# ### end Alembic commands ###
@@ -1,5 +1,6 @@
import enum
import json
from typing import cast

from flask_login import UserMixin # type: ignore
from sqlalchemy import func
@@ -46,13 +47,12 @@ class Account(UserMixin, Base):

@property
def current_tenant(self):
# FIXME: fix the type error later, because the type is important maybe cause some bugs
return self._current_tenant # type: ignore

@current_tenant.setter
def current_tenant(self, value: "Tenant"):
tenant = value
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
if ta:
tenant.current_role = ta.role
else:
@@ -64,25 +64,23 @@ class Account(UserMixin, Base):
def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None

@current_tenant_id.setter
def current_tenant_id(self, value: str):
try:
tenant_account_join = (
def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value)
.filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id)
.one_or_none()
)
),
)

if tenant_account_join:
tenant, ta = tenant_account_join
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
if not tenant_account_join:
return

tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant

@property
@@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
}


class Tenant(db.Model): # type: ignore[name-defined]
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value)


class TenantAccountJoin(db.Model): # type: ignore[name-defined]
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class AccountIntegrate(db.Model): # type: ignore[name-defined]
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class InvitationCode(db.Model): # type: ignore[name-defined]
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
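The `current_tenant_id` property setter is replaced above by an explicit `set_tenant_id()` method that simply returns when the account has no join row for the tenant, instead of swallowing errors in a broad `try/except`; the service-layer hunks later in this diff update the call sites accordingly. A toy illustration of that setter-to-method shape, using stand-in classes rather than the real `Account`/`Tenant` models:

# Stand-in classes; the real models live in models/account.py.
from dataclasses import dataclass, field


@dataclass
class TenantRef:
    id: str


@dataclass
class AccountRef:
    joined_tenant_ids: set = field(default_factory=set)
    _current_tenant: TenantRef | None = None

    def set_tenant_id(self, tenant_id: str) -> None:
        # Mirrors the new behaviour: no join row means no change, no exception.
        if tenant_id not in self.joined_tenant_ids:
            return
        self._current_tenant = TenantRef(id=tenant_id)


account = AccountRef(joined_tenant_ids={"t1"})
account.set_tenant_id("t2")  # not a member: current tenant stays unset
account.set_tenant_id("t1")  # member: current tenant updated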
@@ -2,6 +2,7 @@ import enum

from sqlalchemy import func

from .base import Base
from .engine import db
from .types import StringUUID

@@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output"


class APIBasedExtension(db.Model): # type: ignore[name-defined]
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),
@ -22,6 +22,7 @@ from extensions.ext_storage import storage
|
|||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .model import App, Tag, TagBinding, UploadFile
|
||||
from .types import StringUUID
|
||||
|
|
@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
|
|||
PARTIAL_TEAM = "partial_members"
|
||||
|
||||
|
||||
class Dataset(db.Model): # type: ignore[name-defined]
|
||||
class Dataset(Base):
|
||||
__tablename__ = "datasets"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
|
||||
|
|
@ -92,7 +93,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
|||
@property
|
||||
def latest_process_rule(self):
|
||||
return (
|
||||
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
|
||||
db.session.query(DatasetProcessRule)
|
||||
.filter(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
|
@ -137,7 +139,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
|||
@property
|
||||
def word_count(self):
|
||||
return (
|
||||
Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count)))
|
||||
.filter(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
@ -255,7 +258,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
|
|||
return f"Vector_index_{normalized_dataset_id}_Node"
|
||||
|
||||
|
||||
class DatasetProcessRule(db.Model): # type: ignore[name-defined]
|
||||
class DatasetProcessRule(Base):
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
|
|
@ -295,7 +298,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
|
|||
return None
|
||||
|
||||
|
||||
class Document(db.Model): # type: ignore[name-defined]
|
||||
class Document(Base):
|
||||
__tablename__ = "documents"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_pkey"),
|
||||
|
|
@ -439,12 +442,13 @@ class Document(db.Model): # type: ignore[name-defined]
|
|||
|
||||
@property
|
||||
def segment_count(self):
|
||||
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
|
||||
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
|
||||
|
||||
@property
|
||||
def hit_count(self):
|
||||
return (
|
||||
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
|
||||
.filter(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
|
@ -635,7 +639,7 @@ class Document(db.Model): # type: ignore[name-defined]
|
|||
)
|
||||
|
||||
|
||||
class DocumentSegment(db.Model): # type: ignore[name-defined]
|
||||
class DocumentSegment(Base):
|
||||
__tablename__ = "document_segments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
|
|
@ -786,7 +790,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
|
|||
return text
|
||||
|
||||
|
||||
class ChildChunk(db.Model): # type: ignore[name-defined]
|
||||
class ChildChunk(Base):
|
||||
__tablename__ = "child_chunks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
|
||||
|
|
@ -829,7 +833,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
|
|||
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
|
||||
|
||||
|
||||
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
|
||||
class AppDatasetJoin(Base):
|
||||
__tablename__ = "app_dataset_joins"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
|
||||
|
|
@ -846,7 +850,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
|
|||
return db.session.get(App, self.app_id)
|
||||
|
||||
|
||||
class DatasetQuery(db.Model): # type: ignore[name-defined]
|
||||
class DatasetQuery(Base):
|
||||
__tablename__ = "dataset_queries"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
|
||||
|
|
@ -863,7 +867,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
|
||||
class DatasetKeywordTable(Base):
|
||||
__tablename__ = "dataset_keyword_tables"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
|
||||
|
|
@ -891,7 +895,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
|
|||
return dct
|
||||
|
||||
# get dataset
|
||||
dataset = Dataset.query.filter_by(id=self.dataset_id).first()
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
|
|
@ -908,7 +912,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
|
|||
return None
|
||||
|
||||
|
||||
class Embedding(db.Model): # type: ignore[name-defined]
|
||||
class Embedding(Base):
|
||||
__tablename__ = "embeddings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
|
|
@ -932,7 +936,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
|
|||
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
|
||||
|
||||
|
||||
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
|
||||
class DatasetCollectionBinding(Base):
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
|
|
@ -947,7 +951,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TidbAuthBinding(db.Model): # type: ignore[name-defined]
|
||||
class TidbAuthBinding(Base):
|
||||
__tablename__ = "tidb_auth_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
|
||||
|
|
@ -967,7 +971,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Whitelist(db.Model): # type: ignore[name-defined]
|
||||
class Whitelist(Base):
|
||||
__tablename__ = "whitelists"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
|
||||
|
|
@ -979,7 +983,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetPermission(db.Model): # type: ignore[name-defined]
|
||||
class DatasetPermission(Base):
|
||||
__tablename__ = "dataset_permissions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
|
||||
|
|
@ -996,7 +1000,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
|
||||
class ExternalKnowledgeApis(Base):
|
||||
__tablename__ = "external_knowledge_apis"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
|
||||
|
|
@ -1049,7 +1053,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
|
|||
return dataset_bindings
|
||||
|
||||
|
||||
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
|
||||
class ExternalKnowledgeBindings(Base):
|
||||
__tablename__ = "external_knowledge_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
|
|
@ -1070,7 +1074,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
|
|||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
|
||||
class DatasetAutoDisableLog(Base):
|
||||
__tablename__ = "dataset_auto_disable_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
|
|
@ -1087,7 +1091,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class RateLimitLog(db.Model): # type: ignore[name-defined]
|
||||
class RateLimitLog(Base):
|
||||
__tablename__ = "rate_limit_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
|
||||
|
|
@ -1102,7 +1106,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
|
|||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
class DatasetMetadata(db.Model): # type: ignore[name-defined]
|
||||
class DatasetMetadata(Base):
|
||||
__tablename__ = "dataset_metadatas"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
|
|
@ -1121,7 +1125,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
|
|||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
|
||||
|
||||
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
|
||||
class DatasetMetadataBinding(Base):
|
||||
__tablename__ = "dataset_metadata_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin # type: ignore
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
|
|
@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
|
|||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from libs.helper import generate_string
|
||||
from models.base import Base
|
||||
from models.enums import CreatedByRole
|
||||
from models.workflow import WorkflowRunStatus
|
||||
|
||||
from .account import Account, Tenant
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .enums import CreatedByRole
|
||||
from .types import StringUUID
|
||||
from .workflow import WorkflowRunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
|
@ -602,7 +602,7 @@ class InstalledApp(Base):
|
|||
return tenant
|
||||
|
||||
|
||||
class Conversation(db.Model): # type: ignore[name-defined]
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
|
|
@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
|
|||
|
||||
for message in messages:
|
||||
if message.workflow_run:
|
||||
status_counts[message.workflow_run.status] += 1
|
||||
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
|
||||
|
||||
return (
|
||||
{
|
||||
|
|
@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
|
|||
}
|
||||
|
||||
|
||||
class Message(db.Model): # type: ignore[name-defined]
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="message_pkey"),
|
||||
|
|
@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
|
|||
)
|
||||
|
||||
|
||||
class MessageFeedback(db.Model): # type: ignore[name-defined]
|
||||
class MessageFeedback(Base):
|
||||
__tablename__ = "message_feedbacks"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
|
|
@ -1237,8 +1237,23 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
|
|||
account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
"conversation_id": str(self.conversation_id),
|
||||
"message_id": str(self.message_id),
|
||||
"rating": self.rating,
|
||||
"content": self.content,
|
||||
"from_source": self.from_source,
|
||||
"from_end_user_id": str(self.from_end_user_id) if self.from_end_user_id else None,
|
||||
"from_account_id": str(self.from_account_id) if self.from_account_id else None,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
class MessageFile(db.Model): # type: ignore[name-defined]
|
||||
|
||||
class MessageFile(Base):
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
|
|
@ -1279,7 +1294,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
|
|||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model): # type: ignore[name-defined]
|
||||
class MessageAnnotation(Base):
|
||||
__tablename__ = "message_annotations"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
|
||||
|
|
@ -1310,7 +1325,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
|
|||
return account
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
|
||||
class AppAnnotationHitHistory(Base):
|
||||
__tablename__ = "app_annotation_hit_histories"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||
|
|
@ -1322,7 +1337,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
|
|||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
annotation_id = db.Column(StringUUID, nullable=False)
|
||||
annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
source = db.Column(db.Text, nullable=False)
|
||||
question = db.Column(db.Text, nullable=False)
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
|
|
@ -1348,7 +1363,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
|
|||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
|
||||
class AppAnnotationSetting(Base):
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
|
|
@ -1364,26 +1379,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
|
|||
updated_user_id = db.Column(StringUUID, nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def created_account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
|
||||
.filter(AppAnnotationSetting.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def updated_account(self):
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
|
||||
.filter(AppAnnotationSetting.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
|
|
|||
|
|
@@ -2,8 +2,7 @@ from enum import Enum

from sqlalchemy import func

from models.base import Base

from .base import Base
from .engine import db
from .types import StringUUID

@@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID


class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
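Across the model files in this diff, classes move off `db.Model` (which required `# type: ignore[name-defined]`) onto the project's declarative `Base`, and relative imports replace the `models.`-prefixed ones. A minimal sketch of a typed declarative base in SQLAlchemy 2.x; this `Base` is a stand-in for `models/base.py`, whose definition is not part of this diff:

# Stand-in for models.base.Base; illustrative only.
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleBinding(Base):  # hypothetical model, not the real DataSourceOauthBinding
    __tablename__ = "example_bindings"

    id: Mapped[int] = mapped_column(primary_key=True)
    provider: Mapped[str] = mapped_column(String(64))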
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
|
|
@ -304,8 +304,11 @@ class DeprecatedPublishedAppTool(Base):
|
|||
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
# id of the app
|
||||
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
|
||||
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
# who published this tool
|
||||
description = db.Column(db.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
|
|
@ -325,34 +328,3 @@ class DeprecatedPublishedAppTool(Base):
|
|||
@property
|
||||
def description_i18n(self) -> I18nObject:
|
||||
return I18nObject(**json.loads(self.description))
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
|
||||
file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
|
||||
name: Mapped[str] = mapped_column(default="")
|
||||
size: Mapped[int] = mapped_column(default=-1)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
file_key: str,
|
||||
mimetype: str,
|
||||
original_url: Optional[str] = None,
|
||||
name: str,
|
||||
size: int,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.tenant_id = tenant_id
|
||||
self.conversation_id = conversation_id
|
||||
self.file_key = file_key
|
||||
self.mimetype = mimetype
|
||||
self.original_url = original_url
|
||||
self.name = name
|
||||
self.size = size
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
|||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, PrimaryKeyConstraint, func
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
import contexts
|
||||
|
|
@ -18,11 +18,11 @@ from core.helper import encrypter
|
|||
from core.variables import SecretVariable, Variable
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
from models.base import Base
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .enums import CreatedByRole
|
||||
from .types import StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -736,8 +736,7 @@ class WorkflowAppLog(Base):
|
|||
__tablename__ = "workflow_app_logs"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id", "created_at"),
|
||||
db.Index("workflow_app_log_workflow_run_idx", "workflow_run_id"),
|
||||
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
|
|
@ -769,17 +768,12 @@ class WorkflowAppLog(Base):
|
|||
|
||||
class ConversationVariable(Base):
|
||||
__tablename__ = "workflow_conversation_variables"
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
|
||||
Index("workflow__conversation_variables_app_id_idx", "app_id"),
|
||||
Index("workflow__conversation_variables_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
data = mapped_column(db.Text, nullable=False)
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
|
||||
updated_at = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -14,7 +14,7 @@ dependencies = [
"chardet~=5.1.0",
"flask~=3.1.0",
"flask-compress~=1.17",
"flask-cors~=4.0.0",
"flask-cors~=5.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-restful~=0.3.10",
@@ -63,25 +63,24 @@ dependencies = [
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"pycryptodome==3.19.1",
"pydantic~=2.9.2",
"pydantic-extra-types~=2.9.0",
"pydantic-settings~=2.6.0",
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.9.1",
"pyjwt~=2.8.0",
"pypdfium2~=4.30.0",
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",
"readabilipy==0.2.0",
"redis[hiredis]~=5.0.3",
"resend~=0.7.0",
"sentry-sdk[flask]~=1.44.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.0.0",
"resend~=2.9.0",
"sentry-sdk[flask]~=2.28.0",
"sqlalchemy~=2.0.29",
"starlette==0.41.0",
"tiktoken~=0.9.0",
"tokenizers~=0.15.0",
"transformers~=4.35.0",
"transformers~=4.51.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"weave~=0.51.34",
"weave~=0.51.0",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
]
@@ -195,7 +194,7 @@ vdb = [
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.156",
"volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0",
"xinference-client~=1.2.2",
]
@@ -1,4 +1,5 @@
import datetime
import logging
import time

import click
@@ -20,6 +21,8 @@ from models.model import (
from models.web import SavedMessage
from services.feature_service import FeatureService

_logger = logging.getLogger(__name__)


@app.celery.task(queue="dataset")
def clean_messages():
@@ -46,7 +49,14 @@ def clean_messages():
break
for message in messages:
plan_sandbox_clean_message_day = message.created_at
app = App.query.filter_by(id=message.app_id).first()
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
_logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
@ -2,7 +2,7 @@ import datetime
|
|||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import app
|
||||
|
|
@ -51,8 +51,9 @@ def clean_unused_datasets_task():
|
|||
)
|
||||
|
||||
# Main query with join and filter
|
||||
datasets = (
|
||||
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||
stmt = (
|
||||
select(Dataset)
|
||||
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
||||
.filter(
|
||||
Dataset.created_at < plan_sandbox_clean_day,
|
||||
|
|
@ -60,9 +61,10 @@ def clean_unused_datasets_task():
|
|||
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
||||
)
|
||||
.order_by(Dataset.created_at.desc())
|
||||
.paginate(page=1, per_page=50)
|
||||
)
|
||||
|
||||
datasets = db.paginate(stmt, page=1, per_page=50)
|
||||
|
||||
except NotFound:
|
||||
break
|
||||
if datasets.items is None or len(datasets.items) == 0:
|
||||
|
|
@ -99,7 +101,7 @@ def clean_unused_datasets_task():
|
|||
# update document
|
||||
update_params = {Document.enabled: False}
|
||||
|
||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.commit()
|
||||
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
|
||||
except Exception as e:
|
||||
|
|
@ -135,8 +137,9 @@ def clean_unused_datasets_task():
|
|||
)
|
||||
|
||||
# Main query with join and filter
|
||||
datasets = (
|
||||
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||
stmt = (
|
||||
select(Dataset)
|
||||
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
|
||||
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
|
||||
.filter(
|
||||
Dataset.created_at < plan_pro_clean_day,
|
||||
|
|
@ -144,8 +147,8 @@ def clean_unused_datasets_task():
|
|||
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
|
||||
)
|
||||
.order_by(Dataset.created_at.desc())
|
||||
.paginate(page=1, per_page=50)
|
||||
)
|
||||
datasets = db.paginate(stmt, page=1, per_page=50)
|
||||
|
||||
except NotFound:
|
||||
break
|
||||
|
|
@ -175,7 +178,7 @@ def clean_unused_datasets_task():
|
|||
# update document
|
||||
update_params = {Document.enabled: False}
|
||||
|
||||
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
|
||||
db.session.commit()
|
||||
click.echo(
|
||||
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ def create_tidb_serverless_task():
|
|||
while True:
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
|
||||
idle_tidb_serverless_number = (
|
||||
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
|
||||
)
|
||||
if idle_tidb_serverless_number >= tidb_serverless_number:
|
||||
break
|
||||
# create tidb serverless
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
|
|||
|
||||
# send document clean notify mail
|
||||
try:
|
||||
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
|
||||
dataset_auto_disable_logs = (
|
||||
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
|
||||
)
|
||||
# group by tenant_id
|
||||
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
||||
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
||||
|
|
@ -43,14 +45,16 @@ def mail_clean_document_notify_task():
|
|||
if plan != "sandbox":
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
|
||||
account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
|
||||
if not account:
|
||||
continue
|
||||
|
||||
|
|
@ -63,7 +67,7 @@ def mail_clean_document_notify_task():
|
|||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import click
|
|||
import app
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import TidbAuthBinding
|
||||
|
||||
|
||||
|
|
@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
|
|||
start_at = time.perf_counter()
|
||||
try:
|
||||
# check the number of idle tidb serverless
|
||||
tidb_serverless_list = TidbAuthBinding.query.filter(
|
||||
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
|
||||
).all()
|
||||
tidb_serverless_list = (
|
||||
db.session.query(TidbAuthBinding)
|
||||
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
|
||||
.all()
|
||||
)
|
||||
if len(tidb_serverless_list) == 0:
|
||||
return
|
||||
# update tidb serverless status
|
||||
|
|
|
|||
|
|
@ -108,17 +108,20 @@ class AccountService:
|
|||
if account.status == AccountStatus.BANNED.value:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
|
||||
if current_tenant:
|
||||
account.current_tenant_id = current_tenant.tenant_id
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
else:
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
.first()
|
||||
)
|
||||
if not available_ta:
|
||||
return None
|
||||
|
||||
account.current_tenant_id = available_ta.tenant_id
|
||||
account.set_tenant_id(available_ta.tenant_id)
|
||||
available_ta.current = True
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -297,9 +300,9 @@ class AccountService:
|
|||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
|
||||
account_id=account.id, provider=provider
|
||||
).first()
|
||||
account_integrate: Optional[AccountIntegrate] = (
|
||||
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
|
||||
)
|
||||
|
||||
if account_integrate:
|
||||
# If it exists, update the record
|
||||
|
|
@ -612,7 +615,10 @@ class TenantService:
|
|||
):
|
||||
"""Check if user have a workspace or not"""
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter_by(account_id=account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_ta:
|
||||
|
|
@ -670,7 +676,7 @@ class TenantService:
|
|||
if not tenant:
|
||||
raise TenantNotFoundError("Tenant not found.")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
if ta:
|
||||
tenant.role = ta.role
|
||||
else:
|
||||
|
|
@ -699,12 +705,12 @@ class TenantService:
|
|||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
TenantAccountJoin.query.filter(
|
||||
db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
|
||||
).update({"current": False})
|
||||
tenant_account_join.current = True
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
account.set_tenant_id(tenant_account_join.tenant_id)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -791,7 +797,7 @@ class TenantService:
|
|||
if operator.id == member.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
||||
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
||||
|
||||
if not ta_operator or ta_operator.role not in perms[action]:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
|
@ -804,7 +810,7 @@ class TenantService:
|
|||
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
if not ta:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
|
|
@ -816,15 +822,23 @@ class TenantService:
|
|||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
||||
target_member_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
||||
)
|
||||
|
||||
if not target_member_join:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
if target_member_join.role == new_role:
|
||||
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
|
||||
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join.role = "admin"
|
||||
current_owner_join = (
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
)
|
||||
if current_owner_join:
|
||||
current_owner_join.role = "admin"
|
||||
|
||||
# Update the role of the target member
|
||||
target_member_join.role = new_role
|
||||
|
|
@ -841,7 +855,7 @@ class TenantService:
|
|||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str) -> dict:
|
||||
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
|
||||
tenant = db.get_or_404(Tenant, tenant_id)
|
||||
|
||||
return cast(dict, tenant.custom_config_dict)
|
||||
|
||||
|
|
@ -966,7 +980,7 @@ class RegisterService:
|
|||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,7 @@ from typing import cast

import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound

@@ -124,8 +124,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@@ -133,14 +134,14 @@ class AppAnnotationService:
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
else:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total

@classmethod
@@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")

annotation_hit_histories = (
AppAnnotationHitHistory.query.filter(
stmt = (
select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total

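The pagination change above is the same one applied to the other services in this change set: instead of building a legacy `Query` through `Model.query` and calling `.paginate()` on it, the code builds a Core `select()` statement and hands it to `db.paginate()`. A small sketch of that call shape with a hypothetical `Note` model:

# Sketch of select() + db.paginate(); Note and this app setup are illustrative.
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import select

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class Note(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    body = db.Column(db.Text)


with app.app_context():
    db.create_all()
    stmt = select(Note).filter(Note.body.ilike("%todo%")).order_by(Note.id.desc())
    page = db.paginate(stmt, page=1, per_page=20, max_per_page=100, error_out=False)
    items, total = page.items, page.total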
@ -9,7 +9,7 @@ from collections import Counter
|
|||
from typing import Any, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
|
|
@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
|||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
|
||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||
|
||||
if user:
|
||||
# get permitted dataset ids
|
||||
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
||||
dataset_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
|
||||
)
|
||||
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
|
||||
|
||||
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
|
||||
|
|
@ -129,7 +131,7 @@ class DatasetService:
|
|||
else:
|
||||
return [], 0
|
||||
|
||||
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||
|
||||
return datasets.items, datasets.total
|
||||
|
||||
|
|
@ -153,9 +155,10 @@ class DatasetService:
|
|||
|
||||
@staticmethod
|
||||
def get_datasets_by_ids(ids, tenant_id):
|
||||
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
|
||||
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
|
||||
)
|
||||
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
|
||||
|
||||
return datasets.items, datasets.total
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -174,7 +177,7 @@ class DatasetService:
|
|||
retrieval_model: Optional[RetrievalModel] = None,
|
||||
):
|
||||
# check if dataset name already exists
|
||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
||||
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
|
||||
embedding_model = None
|
||||
if indexing_technique == "high_quality":
|
||||
|
|
@ -235,7 +238,7 @@ class DatasetService:
|
|||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
|
||||
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -436,7 +439,7 @@ class DatasetService:
|
|||
# update Retrieval model
|
||||
filtered_data["retrieval_model"] = data["retrieval_model"]
|
||||
|
||||
dataset.query.filter_by(id=dataset_id).update(filtered_data)
|
||||
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
|
||||
|
||||
db.session.commit()
|
||||
if action:
|
||||
|
|
@ -460,7 +463,7 @@ class DatasetService:
|
|||
|
||||
@staticmethod
|
||||
def dataset_use_check(dataset_id) -> bool:
|
||||
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
|
||||
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
|
||||
if count > 0:
|
||||
return True
|
||||
return False
|
||||
|
|
@ -475,7 +478,9 @@ class DatasetService:
|
|||
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
if dataset.permission == "partial_members":
|
||||
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||
user_permission = (
|
||||
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
|
||||
)
|
||||
if (
|
||||
not user_permission
|
||||
and dataset.tenant_id != user.current_tenant_id
|
||||
|
|
@ -499,23 +504,24 @@ class DatasetService:
|
|||
|
||||
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
if not any(
|
||||
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
|
||||
dp.dataset_id == dataset.id
|
||||
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
|
||||
):
|
||||
raise NoPermissionError("You do not have permission to access this dataset.")
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
|
||||
dataset_queries = (
|
||||
DatasetQuery.query.filter_by(dataset_id=dataset_id)
|
||||
.order_by(db.desc(DatasetQuery.created_at))
|
||||
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||
)
|
||||
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
|
||||
|
||||
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
|
||||
|
||||
return dataset_queries.items, dataset_queries.total
|
||||
|
||||
@staticmethod
|
||||
def get_related_apps(dataset_id: str):
|
||||
return (
|
||||
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
|
||||
db.session.query(AppDatasetJoin)
|
||||
.filter(AppDatasetJoin.dataset_id == dataset_id)
|
||||
.order_by(db.desc(AppDatasetJoin.created_at))
|
||||
.all()
|
||||
)
|
||||
|
@@ -530,10 +536,14 @@ class DatasetService:
         }
         # get recent 30 days auto disable logs
         start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-        dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
-            DatasetAutoDisableLog.dataset_id == dataset_id,
-            DatasetAutoDisableLog.created_at >= start_date,
-        ).all()
+        dataset_auto_disable_logs = (
+            db.session.query(DatasetAutoDisableLog)
+            .filter(
+                DatasetAutoDisableLog.dataset_id == dataset_id,
+                DatasetAutoDisableLog.created_at >= start_date,
+            )
+            .all()
+        )
         if dataset_auto_disable_logs:
             return {
                 "document_ids": [log.document_id for log in dataset_auto_disable_logs],

@@ -873,7 +883,9 @@ class DocumentService:

     @staticmethod
     def get_documents_position(dataset_id):
-        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        document = (
+            db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
+        )
         if document:
             return document.position + 1
         else:
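Single-row lookups are migrated the same way: `Model.query.filter_by(...).first()` becomes `db.session.query(Model).filter_by(...).first()`, keeping the session explicit instead of relying on the legacy `query` class property. A hedged sketch of the `get_documents_position` shape after the change (the fallback value is an assumption, since the `else` branch is cut off in the hunk above):

```python
from typing import Optional

from extensions.ext_database import db
from models.dataset import Document  # model referenced by the hunk above


def get_documents_position(dataset_id: str) -> int:
    # Highest existing position in the dataset, plus one for the next document
    document: Optional[Document] = (
        db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
    )
    if document:
        return document.position + 1
    return 1  # assumed default; the original else branch is not shown in the hunk
```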
@@ -1010,13 +1022,17 @@ class DocumentService:
                 }
             # check duplicate
             if knowledge_config.duplicate:
-                document = Document.query.filter_by(
-                    dataset_id=dataset.id,
-                    tenant_id=current_user.current_tenant_id,
-                    data_source_type="upload_file",
-                    enabled=True,
-                    name=file_name,
-                ).first()
+                document = (
+                    db.session.query(Document)
+                    .filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="upload_file",
+                        enabled=True,
+                        name=file_name,
+                    )
+                    .first()
+                )
                 if document:
                     document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                     document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

@@ -1054,12 +1070,16 @@ class DocumentService:
                 raise ValueError("No notion info list found.")
             exist_page_ids = []
             exist_document = {}
-            documents = Document.query.filter_by(
-                dataset_id=dataset.id,
-                tenant_id=current_user.current_tenant_id,
-                data_source_type="notion_import",
-                enabled=True,
-            ).all()
+            documents = (
+                db.session.query(Document)
+                .filter_by(
+                    dataset_id=dataset.id,
+                    tenant_id=current_user.current_tenant_id,
+                    data_source_type="notion_import",
+                    enabled=True,
+                )
+                .all()
+            )
             if documents:
                 for document in documents:
                     data_source_info = json.loads(document.data_source_info)

@@ -1067,14 +1087,18 @@ class DocumentService:
                         exist_document[data_source_info["notion_page_id"]] = document.id
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
-                data_source_binding = DataSourceOauthBinding.query.filter(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                data_source_binding = (
+                    db.session.query(DataSourceOauthBinding)
+                    .filter(
+                        db.and_(
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == "notion",
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                        )
                     )
-                ).first()
+                    .first()
+                )
                 if not data_source_binding:
                     raise ValueError("Data source binding not found.")
                 for page in notion_info.pages:
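Beyond the session change, the Notion lookup above also shows how the JSON `source_info` column is filtered: `workspace_id` is compared against a JSON-encoded string literal (`f'"{workspace_id}"'`), because the stored JSON text of a string value includes its quotes. The same rewritten query in isolation (the `models.source` import path is an assumption; the columns and filter expressions are taken from the hunk):

```python
from extensions.ext_database import db
from models.source import DataSourceOauthBinding  # import path assumed


def find_notion_binding(tenant_id: str, workspace_id: str):
    # source_info is a JSON column, so the string value is matched with its quotes included
    return (
        db.session.query(DataSourceOauthBinding)
        .filter(
            db.and_(
                DataSourceOauthBinding.tenant_id == tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.disabled == False,  # noqa: E712 - explicit SQL comparison, as in the diff
                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
            )
        )
        .first()
    )
```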
@@ -1206,12 +1230,16 @@ class DocumentService:

     @staticmethod
     def get_tenant_documents_count():
-        documents_count = Document.query.filter(
-            Document.completed_at.isnot(None),
-            Document.enabled == True,
-            Document.archived == False,
-            Document.tenant_id == current_user.current_tenant_id,
-        ).count()
+        documents_count = (
+            db.session.query(Document)
+            .filter(
+                Document.completed_at.isnot(None),
+                Document.enabled == True,
+                Document.archived == False,
+                Document.tenant_id == current_user.current_tenant_id,
+            )
+            .count()
+        )
         return documents_count

     @staticmethod

@@ -1278,14 +1306,18 @@ class DocumentService:
             notion_info_list = document_data.data_source.info_list.notion_info_list
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id
-                data_source_binding = DataSourceOauthBinding.query.filter(
-                    db.and_(
-                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                        DataSourceOauthBinding.provider == "notion",
-                        DataSourceOauthBinding.disabled == False,
-                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                data_source_binding = (
+                    db.session.query(DataSourceOauthBinding)
+                    .filter(
+                        db.and_(
+                            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                            DataSourceOauthBinding.provider == "notion",
+                            DataSourceOauthBinding.disabled == False,
+                            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                        )
                     )
-                ).first()
+                    .first()
+                )
                 if not data_source_binding:
                     raise ValueError("Data source binding not found.")
                 for page in notion_info.pages:

@@ -1328,7 +1360,7 @@ class DocumentService:
             db.session.commit()
             # update document segment
             update_params = {DocumentSegment.status: "re_segment"}
-            DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
+            db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -1918,7 +1950,8 @@ class SegmentService:
     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
         index_node_ids = (
-            DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id)
             .filter(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,

@@ -2157,20 +2190,28 @@ class SegmentService:
     def get_child_chunks(
         cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
     ):
-        query = ChildChunk.query.filter_by(
-            tenant_id=current_user.current_tenant_id,
-            dataset_id=dataset_id,
-            document_id=document_id,
-            segment_id=segment_id,
-        ).order_by(ChildChunk.position.asc())
+        query = (
+            select(ChildChunk)
+            .filter_by(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                segment_id=segment_id,
+            )
+            .order_by(ChildChunk.position.asc())
+        )
         if keyword:
             query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
-        return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

     @classmethod
     def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
         """Get a child chunk by its ID."""
-        result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
+        result = (
+            db.session.query(ChildChunk)
+            .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, ChildChunk) else None

     @classmethod
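`get_child_chunks` combines both halves of the migration: the statement is now a plain `select()`, optional filters are composed onto it with `.where()`, and only then does `db.paginate()` execute it. A minimal sketch of that shape, reusing the `ChildChunk` model from the hunk above:

```python
from typing import Optional

from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import ChildChunk


def list_child_chunks(segment_id: str, page: int, limit: int, keyword: Optional[str] = None):
    stmt = select(ChildChunk).filter_by(segment_id=segment_id).order_by(ChildChunk.position.asc())
    if keyword:
        # Filters compose onto the select before pagination executes it
        stmt = stmt.where(ChildChunk.content.ilike(f"%{keyword}%"))
    return db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
```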
@@ -2184,7 +2225,7 @@ class SegmentService:
         limit: int = 20,
     ):
         """Get segments for a document with optional filtering."""
-        query = DocumentSegment.query.filter(
+        query = select(DocumentSegment).filter(
             DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
         )


@@ -2194,9 +2235,8 @@ class SegmentService:
         if keyword:
             query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))

-        paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
-            page=page, per_page=limit, max_per_page=100, error_out=False
-        )
+        query = query.order_by(DocumentSegment.position.asc())
+        paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

         return paginated_segments.items, paginated_segments.total


@@ -2236,9 +2276,11 @@ class SegmentService:
                 raise ValueError(ex.description)

         # check segment
-        segment = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
-        ).first()
+        segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
+            .first()
+        )
         if not segment:
             raise NotFound("Segment not found.")


@@ -2251,9 +2293,11 @@ class SegmentService:
     @classmethod
     def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
         """Get a segment by its ID."""
-        result = DocumentSegment.query.filter(
-            DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
-        ).first()
+        result = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
+            .first()
+        )
         return result if isinstance(result, DocumentSegment) else None

@@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
 from urllib.parse import urlparse

 import httpx
+from sqlalchemy import select

 from constants import HIDDEN_VALUE
 from core.helper import ssrf_proxy

@@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError

 class ExternalDatasetService:
     @staticmethod
-    def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
-        query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
-            ExternalKnowledgeApis.created_at.desc()
+    def get_external_knowledge_apis(
+        page, per_page, tenant_id, search=None
+    ) -> tuple[list[ExternalKnowledgeApis], int | None]:
+        query = (
+            select(ExternalKnowledgeApis)
+            .filter(ExternalKnowledgeApis.tenant_id == tenant_id)
+            .order_by(ExternalKnowledgeApis.created_at.desc())
         )
         if search:
             query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))

-        external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        external_knowledge_apis = db.paginate(
+            select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
+        )

         return external_knowledge_apis.items, external_knowledge_apis.total

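Note the return annotation loosening from `int` to `int | None`: Flask-SQLAlchemy's `Pagination.total` is optional (it is `None` when counting is disabled), so the signature now matches what `db.paginate()` can actually hand back. A hedged sketch of a caller that tolerates the optional total (the import path and the response shape are assumptions, not part of this diff):

```python
from services.external_knowledge_service import ExternalDatasetService  # import path assumed


def render_api_list(page: int, per_page: int, tenant_id: str) -> dict:
    # Uses the migrated static method shown above; the dict layout is illustrative only
    apis, total = ExternalDatasetService.get_external_knowledge_apis(page, per_page, tenant_id)
    return {
        "data": [api.name for api in apis],
        "total": total if total is not None else len(apis),  # Pagination.total may be None
        "page": page,
    }
```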
@@ -92,18 +99,18 @@ class ExternalDatasetService:

     @staticmethod
     def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         return external_knowledge_api

     @staticmethod
     def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
-        external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api: Optional[ExternalKnowledgeApis] = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:

@@ -120,9 +127,9 @@ class ExternalDatasetService:

     @staticmethod
     def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")


@@ -131,25 +138,29 @@ class ExternalDatasetService:

     @staticmethod
     def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
-        count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
+        count = (
+            db.session.query(ExternalKnowledgeBindings)
+            .filter_by(external_knowledge_api_id=external_knowledge_api_id)
+            .count()
+        )
         if count > 0:
             return True, count
         return False, 0

     @staticmethod
     def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
-        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")
         return external_knowledge_binding

     @staticmethod
     def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_api_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
+        )
         if external_knowledge_api is None:
             raise ValueError("api template not found")
         settings = json.loads(external_knowledge_api.settings)

@@ -212,11 +223,13 @@ class ExternalDatasetService:
     @staticmethod
     def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
         # check if dataset name already exists
-        if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
+        if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
+            .first()
+        )

         if external_knowledge_api is None:
             raise ValueError("api template not found")

@@ -254,15 +267,17 @@ class ExternalDatasetService:
         external_retrieval_parameters: dict,
         metadata_condition: Optional[MetadataCondition] = None,
     ) -> list:
-        external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
-            dataset_id=dataset_id, tenant_id=tenant_id
-        ).first()
+        external_knowledge_binding = (
+            db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+        )
         if not external_knowledge_binding:
             raise ValueError("external knowledge binding not found")

-        external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
-            id=external_knowledge_binding.external_knowledge_api_id
-        ).first()
+        external_knowledge_api = (
+            db.session.query(ExternalKnowledgeApis)
+            .filter_by(id=external_knowledge_binding.external_knowledge_api_id)
+            .first()
+        )
         if not external_knowledge_api:
             raise ValueError("external api template not found")

@@ -69,6 +69,7 @@ class HitTestingService:
         query: str,
         account: Account,
         external_retrieval_model: dict,
+        metadata_filtering_conditions: dict,
     ) -> dict:
         if dataset.provider != "external":
             return {

@@ -82,6 +83,7 @@ class HitTestingService:
             dataset_id=dataset.id,
             query=cls.escape_query_for_search(query),
             external_retrieval_model=external_retrieval_model,
+            metadata_filtering_conditions=metadata_filtering_conditions,
         )

         end = time.perf_counter()
@@ -177,6 +177,21 @@ class MessageService:

         return feedback

+    @classmethod
+    def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
+        """Get all feedbacks of an app"""
+        offset = (page - 1) * limit
+        feedbacks = (
+            db.session.query(MessageFeedback)
+            .filter(MessageFeedback.app_id == app_model.id)
+            .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
+            .limit(limit)
+            .offset(offset)
+            .all()
+        )
+
+        return [record.to_dict() for record in feedbacks]
+
     @classmethod
     def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
         message = (
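The new `get_all_messages_feedbacks` pages manually with `limit`/`offset` and a stable `(created_at, id)` ordering rather than going through `db.paginate()`, since it only needs the rows, not a total count. A small usage sketch for draining all pages (the `MessageService` import path is assumed):

```python
from services.message_service import MessageService  # import path assumed


def iter_all_feedbacks(app_model, page_size: int = 100):
    # Walk pages until a short page signals the end of the data
    page = 1
    while True:
        batch = MessageService.get_all_messages_feedbacks(app_model, page=page, limit=page_size)
        yield from batch
        if len(batch) < page_size:
            break
        page += 1
```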
@@ -20,9 +20,11 @@ class MetadataService:
     @staticmethod
     def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == metadata_args.name:

@@ -42,16 +44,18 @@ class MetadataService:
     def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata:  # type: ignore
         lock_key = f"dataset_metadata_lock_{dataset_id}"
         # check if metadata name already exists
-        if DatasetMetadata.query.filter_by(
-            tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
-        ).first():
+        if (
+            db.session.query(DatasetMetadata)
+            .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
+            .first()
+        ):
             raise ValueError("Metadata name already exists.")
         for field in BuiltInField:
             if field.value == name:
                 raise ValueError("Metadata name already exists in Built-in fields.")
         try:
             MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
-            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
             if metadata is None:
                 raise ValueError("Metadata not found.")
             old_name = metadata.name

@@ -60,7 +64,9 @@ class MetadataService:
             metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

             # update related documents
-            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            dataset_metadata_bindings = (
+                db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
+            )
             if dataset_metadata_bindings:
                 document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                 documents = DocumentService.get_document_by_ids(document_ids)

@@ -82,13 +88,15 @@ class MetadataService:
         lock_key = f"dataset_metadata_lock_{dataset_id}"
         try:
             MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
-            metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
+            metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
             if metadata is None:
                 raise ValueError("Metadata not found.")
             db.session.delete(metadata)

             # deal related documents
-            dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
+            dataset_metadata_bindings = (
+                db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
+            )
             if dataset_metadata_bindings:
                 document_ids = [binding.document_id for binding in dataset_metadata_bindings]
                 documents = DocumentService.get_document_by_ids(document_ids)

@@ -193,7 +201,7 @@ class MetadataService:
                 db.session.add(document)
             db.session.commit()
             # deal metadata binding
-            DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
+            db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
             for metadata_value in operation.metadata_list:
                 dataset_metadata_binding = DatasetMetadataBinding(
                     tenant_id=current_user.current_tenant_id,

@@ -230,9 +238,9 @@ class MetadataService:
                 "id": item.get("id"),
                 "name": item.get("name"),
                 "type": item.get("type"),
-                "count": DatasetMetadataBinding.query.filter_by(
-                    metadata_id=item.get("id"), dataset_id=dataset.id
-                ).count(),
+                "count": db.session.query(DatasetMetadataBinding)
+                .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
+                .count(),
             }
             for item in dataset.doc_metadata or []
             if item.get("id") != "built-in"
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional

 from core.model_manager import ModelInstance, ModelManager

@@ -12,17 +13,27 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
 from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import ParentMode

+_logger = logging.getLogger(__name__)
+

 class VectorService:
     @classmethod
     def create_segments_vector(
         cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
     ):
-        documents = []
+        documents: list[Document] = []

+        document: Document | None = None
         for segment in segments:
             if doc_form == IndexType.PARENT_CHILD_INDEX:
-                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
+                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+                if not document:
+                    _logger.warning(
+                        "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
+                        segment.document_id,
+                        segment.id,
+                    )
+                    continue
                 # get the process rule
                 processing_rule = (
                     db.session.query(DatasetProcessRule)
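`VectorService` now guards the parent-document lookup: a module-level `_logger` is introduced, and a missing `DatasetDocument` row is logged and skipped instead of letting `None` flow into indexing. The same defensive shape in isolation (the `handle` callable is a stand-in for the real per-segment indexing work):

```python
import logging

from extensions.ext_database import db
from models.dataset import Document as DatasetDocument

_logger = logging.getLogger(__name__)


def index_segments(segments, handle):
    # handle is a stand-in callable for the real indexing logic
    for segment in segments:
        document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
        if not document:
            # Lazy %-formatting keeps the log call cheap when the level is disabled
            _logger.warning(
                "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                segment.document_id,
                segment.id,
            )
            continue
        handle(segment, document)
```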
@@ -2,7 +2,7 @@ import threading
 from typing import Optional

 import contexts
-from core.workflow.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.workflow.repository.workflow_node_execution_repository import OrderConfig
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination

@@ -129,12 +129,8 @@ class WorkflowRunService:
             return []

         # Use the repository to get the node executions
-        repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": app_model.tenant_id,
-                "app_id": app_model.id,
-                "session_factory": db.session.get_bind(),
-            }
+        repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
         )

         # Use the repository to get the node executions with ordering
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.errors import WorkflowNodeRunFailedError

@@ -21,7 +22,6 @@ from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.event.types import NodeEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-from core.workflow.repository import RepositoryFactory
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
 from extensions.ext_database import db

@@ -285,12 +285,8 @@ class WorkflowService:
         workflow_node_execution.workflow_id = draft_workflow.id

         # Use the repository to save the workflow node execution
-        repository = RepositoryFactory.create_workflow_node_execution_repository(
-            params={
-                "tenant_id": app_model.tenant_id,
-                "app_id": app_model.id,
-                "session_factory": db.session.get_bind(),
-            }
+        repository = SQLAlchemyWorkflowNodeExecutionRepository(
+            session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
        )
        repository.save(workflow_node_execution)

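Both workflow services drop `RepositoryFactory` in favour of constructing `SQLAlchemyWorkflowNodeExecutionRepository` directly, passing the engine rather than `db.session.get_bind()`. A hedged sketch of the new call shape, using only the constructor arguments visible in the hunks above:

```python
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db


def build_node_execution_repository(app_model):
    # Mirrors the diff: the engine acts as the session factory, scoped to the
    # tenant and app that own the node executions.
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
    )
```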
Some files were not shown because too many files have changed in this diff.