Merge branch 'main' into e-300

NFish 2025-05-15 10:34:44 +08:00
commit b38286f06a
523 changed files with 14695 additions and 4125 deletions

View File

@ -476,6 +476,7 @@ LOGIN_LOCKOUT_DURATION=86400
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000

View File

@ -54,7 +54,6 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
ext_repositories,
ext_sentry,
ext_set_secretkey,
ext_storage,
@ -75,7 +74,6 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_repositories,
ext_celery,
ext_login,
ext_mail,

View File

@ -6,6 +6,7 @@ from typing import Optional
import click
from flask import current_app
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
page = 1
while True:
try:
datasets = (
Dataset.query.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
@ -551,11 +552,12 @@ def old_metadata_migration():
page = 1
while True:
try:
documents = (
DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None)
stmt = (
select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc())
.paginate(page=page, per_page=50)
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound:
break
if not documents:
@ -592,11 +594,15 @@ def old_metadata_migration():
)
db.session.add(dataset_metadata_binding)
else:
dataset_metadata_binding = DatasetMetadataBinding.query.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
).first()
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
)
.first()
)
if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id,
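
The pagination changes above, and the matching ones in the dataset controllers further down, all follow the same recipe: the legacy Flask-SQLAlchemy Model.query ... .paginate(...) chain is replaced by an explicit select() statement handed to db.paginate(). A minimal sketch of the pattern, using a hypothetical Item model rather than the real Dify models:

from sqlalchemy import select
from extensions.ext_database import db
from models import Item  # hypothetical model, for illustration only

# Legacy pattern (Flask-SQLAlchemy query property), as removed throughout this merge:
#   items = Item.query.filter(Item.enabled == True).order_by(Item.created_at.desc()).paginate(page=1, per_page=50)

# New pattern: build a 2.0-style select() and hand it to db.paginate().
stmt = select(Item).filter(Item.enabled == True).order_by(Item.created_at.desc())
items = db.paginate(select=stmt, page=1, per_page=50, max_per_page=50, error_out=False)
for item in items.items:
    print(item.id)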

View File

@ -74,7 +74,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service",
default="http://sandbox:8194",
default=HttpUrl("http://sandbox:8194"),
)
CODE_EXECUTION_API_KEY: str = Field(
@ -145,7 +145,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL",
default="http://localhost:5002",
default=HttpUrl("http://localhost:5002"),
)
PLUGIN_DAEMON_KEY: str = Field(
@ -188,7 +188,7 @@ class MarketplaceConfig(BaseSettings):
MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL",
default="https://marketplace.dify.ai",
default=HttpUrl("https://marketplace.dify.ai"),
)
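
Wrapping the defaults in HttpUrl(...) keeps the default value consistent with the field's declared HttpUrl type under pydantic v2, so a bare string default no longer trips strict type checking. A small self-contained sketch (an illustrative config class, not the real Dify ones):

from pydantic import Field, HttpUrl
from pydantic_settings import BaseSettings

class ExampleConfig(BaseSettings):  # illustrative only
    # The annotation is HttpUrl, so the default is constructed as an HttpUrl
    # instance instead of a plain str.
    SERVICE_URL: HttpUrl = Field(
        description="Example service URL",
        default=HttpUrl("http://sandbox:8194"),
    )

print(ExampleConfig().SERVICE_URL)  # http://sandbox:8194/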

View File

@ -1,6 +1,6 @@
import os
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@ -173,17 +173,31 @@ class DatabaseConfig(BaseSettings):
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count(),
default=os.cpu_count() or 1,
)
@computed_field
@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
"connect_args": connect_args,
}
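
The rewritten SQLALCHEMY_ENGINE_OPTIONS now merges any options value supplied through DB_EXTRAS with the UTC timezone flag that was previously hard-coded, instead of discarding user-provided options. The merge behaviour, isolated into a small sketch (input values are illustrative):

from urllib.parse import parse_qsl

def build_connect_args(db_extras: str) -> dict:
    # Same logic as the computed property above.
    options = dict(parse_qsl(db_extras)).get("options", "")
    timezone_opt = "-c timezone=UTC"
    merged_options = f"{options} {timezone_opt}" if options else timezone_opt
    return {"options": merged_options}

print(build_connect_args(""))
# {'options': '-c timezone=UTC'}
print(build_connect_args("options=-c search_path=myschema"))
# {'options': '-c search_path=myschema -c timezone=UTC'}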

View File

@ -83,3 +83,13 @@ class RedisConfig(BaseSettings):
description="Password for Redis Clusters authentication (if required)",
default=None,
)
REDIS_SERIALIZATION_PROTOCOL: int = Field(
description="Redis serialization protocol (RESP) version",
default=3,
)
REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field(
description="Enable client side cache in redis",
default=False,
)
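
REDIS_SERIALIZATION_PROTOCOL defaults to 3 because client-side caching in redis-py is only available over RESP3. A rough sketch of how the two settings could map onto the client, assuming redis-py 5.1 or newer (this is not the actual ext_redis wiring):

import redis
from redis.cache import CacheConfig  # available in redis-py >= 5.1

# Client-side caching requires protocol=3 (RESP3); leave cache_config unset when
# REDIS_ENABLE_CLIENT_SIDE_CACHE is false.
client = redis.Redis(
    host="localhost",
    port=6379,
    protocol=3,                  # REDIS_SERIALIZATION_PROTOCOL
    cache_config=CacheConfig(),  # REDIS_ENABLE_CLIENT_SIDE_CACHE
)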

View File

@ -27,6 +27,11 @@ class OTelConfig(BaseSettings):
default="otlp",
)
OTEL_EXPORTER_OTLP_PROTOCOL: str = Field(
description="OTLP exporter protocol ('grpc' or 'http')",
default="http",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
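
The new OTEL_EXPORTER_OTLP_PROTOCOL setting chooses between the gRPC and HTTP OTLP span exporters. A hedged sketch of how that selection might look, not the actual ext_otel extension; the /v1/traces suffix follows the usual OTLP-over-HTTP convention:

from configs import dify_config

def build_span_exporter():  # hypothetical helper, for illustration only
    if dify_config.OTEL_EXPORTER_OTLP_PROTOCOL == "grpc":
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
        return OTLPSpanExporter(endpoint=dify_config.OTLP_BASE_ENDPOINT)
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    return OTLPSpanExporter(endpoint=f"{dify_config.OTLP_BASE_ENDPOINT}/v1/traces")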

View File

@ -1,5 +1,7 @@
from flask_restful import fields
from libs.helper import AppIconUrlField
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
@ -22,3 +24,20 @@ parameters_fields = {
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}

View File

@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields))

View File

@ -6,7 +6,7 @@ from typing import cast
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource):
limits = DocumentService.DEFAULT_RULES["limits"]
if document_id:
# get the latest process rule
document = Document.query.get_or_404(document_id)
document = db.get_or_404(Document, document_id)
dataset = DatasetService.get_dataset(document.dataset_id)
@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
search = f"%{search}%"
@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource):
desc(Document.position),
)
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments

View File

@ -4,6 +4,7 @@ import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -26,6 +27,7 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).order_by(DocumentSegment.position.asc())
query = (
select(DocumentSegment)
.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
response = {
"data": marshal(segments.items, segment_fields),
@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_dataset_editor:
@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

View File

@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)
return response

View File

@ -87,7 +87,7 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")
@ -100,9 +100,11 @@ class InstalledAppsListApi(Resource):
if not app.is_public:
raise Forbidden("You can't install a non-public app")
installed_app = InstalledApp.query.filter(
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)
).first()
installed_app = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)
if installed_app is None:
# todo: position

View File

@ -79,7 +79,6 @@ class MemberInviteEmailApi(Resource):
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
break
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})

View File

@ -3,6 +3,7 @@ import logging
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
@ -88,9 +89,8 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(
page=args["page"], per_page=args["limit"], error_out=False
)
stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
has_more = False
if tenants.has_next:
@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"],
@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource):
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"]
db.session.commit()

View File

@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp)
from . import index
from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

View File

@ -93,6 +93,18 @@ class MessageFeedbackApi(Resource):
return {"result": "success"}
class AppGetFeedbacksApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get All Feedbacks of an app"""
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args")
args = parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
return {"data": feedbacks}
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
@ -119,3 +131,4 @@ class MessageSuggestedApi(Resource):
api.add_resource(MessageListApi, "/messages")
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
api.add_resource(AppGetFeedbacksApi, "/app/feedbacks")
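
The new /app/feedbacks route is guarded by validate_app_token, so it is called with the app's API key. An illustrative client call (base URL and key are placeholders):

import requests  # illustrative only

resp = requests.get(
    "https://api.dify.ai/v1/app/feedbacks",            # replace with your API base URL
    headers={"Authorization": "Bearer <app_api_key>"},  # placeholder token
    params={"page": 1, "limit": 20},
)
print(resp.json())  # {"data": [...]}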

View File

@ -0,0 +1,30 @@
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
from controllers.common import fields
from controllers.service_api import api
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
class AppSiteApi(Resource):
"""Resource for app sites."""
@validate_app_token
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
api.add_resource(AppSiteApi, "/site")
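
Like the feedbacks route above, the new /site route authenticates with the app API key and returns the site record marshalled through the shared site_fields added earlier in this merge. An illustrative call (placeholders again):

import requests  # illustrative only

resp = requests.get(
    "https://api.dify.ai/v1/site",                      # replace with your API base URL
    headers={"Authorization": "Bearer <app_api_key>"},  # placeholder token
)
print(resp.json()["title"])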

View File

@ -313,7 +313,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
return 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:

View File

@ -2,10 +2,10 @@ import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc
from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound
import services.dataset_service
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return 204
class DocumentListApi(DatasetApiResource):
@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset:
raise NotFound("Dataset not found.")
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f"%{search}%"
@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
raise NotFound("Documents not found.")
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:

View File

@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 204
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 204
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")

View File

@ -1,11 +1,11 @@
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from controllers.common import fields
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import \
get_parameters_from_feature_dict
from flask import request
from flask_restful import Resource, marshal_with, reqparse
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from models.model import App, AppMode
from services.app_service import AppService

View File

@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=cast(dict, application_generate_entity.inputs),
)
# get how many agent thoughts have been created
self.agent_thought_count = (

View File

@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
@ -87,6 +80,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price

View File

@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
assert app_config.agent
iteration_step = 1
@ -72,6 +65,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price

View File

@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
@ -163,12 +163,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@ -231,12 +229,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@ -297,12 +293,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(

View File

@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@ -58,7 +57,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
@ -66,6 +65,7 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
@ -113,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,

View File

@ -18,13 +18,13 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
@ -138,12 +138,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@ -264,12 +262,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(
@ -329,12 +325,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": application_generate_entity.app_config.tenant_id,
"app_id": application_generate_entity.app_config.app_id,
"session_factory": session_factory,
}
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
tenant_id=application_generate_entity.app_config.tenant_id,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate(

View File

@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
@ -45,6 +44,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (

View File

@ -0,0 +1 @@
# Core base package

View File

@ -0,0 +1,6 @@
from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
__all__ = [
"AppGeneratorTTSPublisher",
"AudioTrunk",
]

View File

@ -1,3 +1,5 @@
import logging
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -7,6 +9,8 @@ from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
_logger = logging.getLogger(__name__)
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
@ -42,18 +46,31 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end."""
for document in documents:
if document.metadata is not None:
dataset_document = DatasetDocument.query.filter(
DatasetDocument.id == document.metadata["document_id"]
).first()
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(

View File

@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -103,15 +103,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
@ -162,15 +164,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is indexing."""
try:
# get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id, document_id=dataset_document.id
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = []
if document_segments:
@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@ -587,7 +591,7 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
@ -656,10 +660,10 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first()
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedError()
@ -668,7 +672,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.commit()
@staticmethod
@ -676,7 +680,7 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(

View File

@ -1,9 +1,9 @@
from enum import Enum
from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
class TracingProviderEnum(StrEnum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
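
Moving TracingProviderEnum from Enum to StrEnum makes its members compare equal to, and format as, their raw string values, which is what lets provider names arriving as plain strings keep matching the enum-based lookups (such as the OpsTraceProviderConfigMap later in this merge). A tiny sketch of the difference:

from enum import Enum, StrEnum  # StrEnum requires Python 3.11+

class OldStyle(Enum):
    LANGFUSE = "langfuse"

class NewStyle(StrEnum):
    LANGFUSE = "langfuse"

print(OldStyle.LANGFUSE == "langfuse")  # False
print(NewStyle.LANGFUSE == "langfuse")  # True
print(f"{NewStyle.LANGFUSE}")           # langfuse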

View File

@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser
@ -113,8 +113,8 @@ class LangFuseDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id
)
# Get all executions for this workflow run

View File

@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@ -137,12 +137,8 @@ class LangSmithDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run

View File

@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.workflow.repository.repository_factory import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@ -150,12 +150,8 @@ class OpikDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run

View File

@ -16,11 +16,7 @@ from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
TracingProviderEnum,
WeaveConfig,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
@ -33,11 +29,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@ -45,36 +37,58 @@ from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
def build_opik_trace_instance(config: OpikConfig):
return OpikDataTrace(config)
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
return {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
}
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
}
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
}
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"trace_instance": WeaveDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map: dict[str, dict[str, Any]] = {
TracingProviderEnum.LANGFUSE.value: {
"config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace,
},
TracingProviderEnum.LANGSMITH.value: {
"config_class": LangSmithConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace,
},
TracingProviderEnum.OPIK.value: {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": lambda config: build_opik_trace_instance(config),
},
TracingProviderEnum.WEAVE.value: {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"trace_instance": WeaveDataTrace,
},
}
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
class OpsTraceManager:

View File

@ -24,7 +24,7 @@ class EndpointProviderDeclaration(BaseModel):
"""
settings: list[ProviderConfig] = Field(default_factory=list)
endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list)
endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration])
class EndpointEntity(BasePluginEntity):

View File

@ -52,7 +52,7 @@ class PluginResourceRequirements(BaseModel):
model: Optional[Model] = Field(default=None)
node: Optional[Node] = Field(default=None)
endpoint: Optional[Endpoint] = Field(default=None)
storage: Storage = Field(default=None)
storage: Optional[Storage] = Field(default=None)
permission: Optional[Permission] = Field(default=None)
@ -66,9 +66,9 @@ class PluginCategory(enum.StrEnum):
class PluginDeclaration(BaseModel):
class Plugins(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list)
models: Optional[list[str]] = Field(default_factory=list)
endpoints: Optional[list[str]] = Field(default_factory=list)
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -84,6 +84,7 @@ class PluginDeclaration(BaseModel):
resource: PluginResourceRequirements
plugins: Plugins
tags: list[str] = Field(default_factory=list)
repo: Optional[str] = Field(default=None)
verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None
model: Optional[ProviderEntity] = None

View File

@ -55,8 +55,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
mode: str
completion_params: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
stop: Optional[list[str]] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool])
stop: Optional[list[str]] = Field(default_factory=list[str])
stream: Optional[bool] = False
model_config = ConfigDict(protected_namespaces=())

View File

@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
@ -119,12 +120,25 @@ class RetrievalService:
return all_documents
@classmethod
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
def external_retrieve(
cls,
dataset_id: str,
query: str,
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model or {},
metadata_condition=metadata_condition,
)
return all_documents

View File

@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
data_source_info["last_edited_time"] = last_edited_time
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
db.session.commit()
def get_notion_last_edited_time(self) -> str:
@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise Exception(

View File

@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor):
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
def _extract_images_from_docx(self, doc, image_folder):
os.makedirs(image_folder, exist_ok=True)
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
@ -210,7 +209,7 @@ class WordExtractor(BaseExtractor):
content = []
image_map = self._extract_images_from_docx(doc, image_folder)
image_map = self._extract_images_from_docx(doc)
hyperlinks_url = None
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
@@ -225,7 +224,7 @@
xml = ElementTree.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None]
for x in x_child:
if x_child is None:
if x is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:

View File

@@ -149,7 +149,7 @@ class DatasetRetrieval:
else:
inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets]
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
available_datasets_ids,
query,
tenant_id,
@@ -237,12 +237,16 @@
if show_retrieve_source:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,
@@ -506,19 +510,30 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document = DatasetDocument.query.filter(
DatasetDocument.id == document.metadata["document_id"]
).first()
dataset_document = (
db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
).first()
child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
)
if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
)
db.session.commit()
else:
@@ -649,6 +664,8 @@ class DatasetRetrieval:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -706,6 +723,9 @@
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
retrieve_config=retrieve_config,
user_id=user_id,
inputs=inputs,
)
tools.append(tool)
@@ -826,7 +846,7 @@
)
return filter_documents[:top_k] if top_k else filter_documents
def _get_metadata_filter_condition(
def get_metadata_filter_condition(
self,
dataset_ids: list,
query: str,
@@ -876,20 +896,31 @@
)
elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions:
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
conditions = []
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
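As a rough illustration of what the manual branch above ends up building (metadata names, operators and values here are hypothetical), the resulting object looks like:

metadata_condition = MetadataCondition(
    logical_operator="or",
    conditions=[
        Condition(name="category", comparison_operator="contains", value="faq"),
        Condition(name="language", comparison_operator="contains", value="en"),
    ],
)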

View File

@@ -4,3 +4,9 @@ Repository implementations for data access.
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"SQLAlchemyWorkflowNodeExecutionRepository",
]
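With the registry and factory extension removed (see the deleted modules below), callers construct the repository directly. A sketch mirroring what the deleted factory used to do; the tenant ID is a placeholder, and binding the sessionmaker to the global engine is an assumption carried over from the old factory.

from sqlalchemy.orm import sessionmaker

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=sessionmaker(bind=db.engine),
    tenant_id="tenant-id",  # placeholder
    app_id=None,            # optional filter
)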

View File

@@ -1,87 +0,0 @@
"""
Registry for repository implementations.
This module is responsible for registering factory functions with the repository factory.
"""
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
logger = logging.getLogger(__name__)
# Storage type constants
STORAGE_TYPE_RDBMS = "rdbms"
STORAGE_TYPE_HYBRID = "hybrid"
def register_repositories() -> None:
"""
Register repository factory functions with the RepositoryFactory.
This function reads configuration settings to determine which repository
implementations to register.
"""
# Configure WorkflowNodeExecutionRepository factory based on configuration
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
# Check storage type and register appropriate implementation
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
# Register SQLAlchemy implementation for RDBMS storage
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
# Hybrid storage is not yet implemented
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
else:
# Unknown storage type
raise ValueError(
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
f"Supported types: {STORAGE_TYPE_RDBMS}"
)
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
"""
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
This factory function creates a repository for the RDBMS storage type.
Args:
params: Parameters for creating the repository, including:
- tenant_id: Required. The tenant ID for multi-tenancy.
- app_id: Optional. The application ID for filtering.
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
a new sessionmaker will be created using the global database engine.
Returns:
A WorkflowNodeExecutionRepository instance
Raises:
ValueError: If required parameters are missing
"""
# Extract required parameters
tenant_id = params.get("tenant_id")
if tenant_id is None:
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
# Extract optional parameters
app_id = params.get("app_id")
# Use the session_factory from params if provided, otherwise create one using the global db engine
session_factory = params.get("session_factory")
if session_factory is None:
# Create a sessionmaker using the same engine as the global db session
session_factory = sessionmaker(bind=db.engine)
# Create and return the repository
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
)

View File

@@ -10,13 +10,13 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class SQLAlchemyWorkflowNodeExecutionRepository:
class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.

View File

@@ -1,9 +0,0 @@
"""
WorkflowNodeExecution repository implementations.
"""
from core.repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
)
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"position": resource_number,

View File

@@ -1,11 +1,12 @@
from typing import Any
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
metadata_filtering_conditions: MetadataCondition
user_id: Optional[str] = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
metadata_filtering_conditions=MetadataCondition(),
**kwargs,
)
@@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
dataset_retrieval = DatasetRetrieval()
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
[dataset.id],
query,
self.tenant_id,
self.user_id or "unknown",
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
else:
document_ids_filter = None
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=self.metadata_filtering_conditions,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = RetrievalDocument(
@@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results]))
else:
if metadata_condition and not document_ids_filter:
return ""
            # get the retrieval model; if it is not set, use the default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
document_ids_filter=document_ids_filter,
)
return str("\n".join([document.page_content for document in documents]))
else:
@@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
document_ids_filter=document_ids_filter,
)
else:
documents = []
@@ -161,12 +185,16 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,

View File

@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=inputs,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []

View File

@@ -30,7 +30,7 @@ class Variable(Segment):
"""
id: str = Field(
default=lambda _: str(uuid4()),
default_factory=lambda: str(uuid4()),
description="Unique identity for variable.",
)
name: str
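The fix above matters because Field(default=...) stores the callable itself as the default value, while Field(default_factory=...) calls it once per instance. A minimal, self-contained illustration (not from this codebase):

from uuid import uuid4
from pydantic import BaseModel, Field

class Fixed(BaseModel):
    # default_factory is invoked for every new instance, yielding a fresh UUID string;
    # with default=lambda _: str(uuid4()), the lambda object itself would be stored instead.
    id: str = Field(default_factory=lambda: str(uuid4()))

print(Fixed().id != Fixed().id)  # True: each instance gets its own id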

View File

@@ -36,7 +36,7 @@ class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=list, description="node configs mapping (node id: node config)"
default_factory=dict, description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="graph edge mapping (source node id: edges)"

View File

@@ -95,7 +95,12 @@ class StreamProcessor(ABC):
if node_id not in self.rest_node_ids:
return
if node_id in reachable_node_ids:
return
self.rest_node_ids.remove(node_id)
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue

View File

@@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]):
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
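A one-line illustration of the string fix above: ${...} is shell/template syntax, not f-string syntax, so the dollar sign was emitted literally.

depth_limit = 5
print(f"Depth limit ${depth_limit} reached")  # old: "Depth limit $5 reached"
print(f"Depth limit {depth_limit} reached")   # fixed: "Depth limit 5 reached"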

View File

@@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]):
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
        ensures the iteration context (ID, and index or parallel_mode_run_id) is added to the node run metadata.
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
return event
iter_metadata = {
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
}
if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
}
event.route_node_state.node_run_result.metadata = metadata
current_metadata = event.route_node_state.node_run_result.metadata or {}
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event
def _run_single_iter(

View File

@@ -264,6 +264,7 @@ class KnowledgeRetrievalNode(LLMNode):
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
@@ -275,12 +276,16 @@
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
)
if dataset and document:
source = {
"metadata": {
@@ -289,7 +294,7 @@
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"document_data_source_type": document.data_source_type,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
@@ -356,12 +361,12 @@
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
conditions = []
if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
@@ -372,13 +377,24 @@
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
else:
raise ValueError("Invalid expected metadata value type")
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = self._process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:

View File

@@ -506,7 +506,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("document_data_source_type"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),

View File

@@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list)
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None

View File

@@ -337,7 +337,7 @@ class LoopNode(BaseNode[LoopNodeData]):
return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent):
# Loop run failed
yield event
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,

View File

@@ -6,10 +6,9 @@ for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.workflow.repository.repository_factory import RepositoryFactory
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
__all__ = [
"RepositoryFactory",
"OrderConfig",
"WorkflowNodeExecutionRepository",
]

View File

@@ -1,97 +0,0 @@
"""
Repository factory for creating repository instances.
This module provides a simple factory interface for creating repository instances.
It does not contain any implementation details or dependencies on specific repositories.
"""
from collections.abc import Callable, Mapping
from typing import Any, Literal, Optional, cast
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
# Type for factory functions - takes a dict of parameters and returns any repository type
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
# Type for workflow node execution factory function
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
# Repository type literals
_RepositoryType = Literal["workflow_node_execution"]
class RepositoryFactory:
"""
Factory class for creating repository instances.
This factory delegates the actual repository creation to implementation-specific
factory functions that are registered with the factory at runtime.
"""
# Dictionary to store factory functions
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
@classmethod
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
"""
Register a factory function for a specific repository type.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository (e.g., 'workflow_node_execution')
factory_func: A function that takes parameters and returns a repository instance
"""
cls._factory_functions[repository_type] = factory_func
@classmethod
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
"""
Create a new repository instance with the provided parameters.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository to create
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the requested repository
Raises:
ValueError: If no factory function is registered for the repository type
"""
if repository_type not in cls._factory_functions:
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
# Use empty dict if params is None
params = params or {}
return cls._factory_functions[repository_type](params)
@classmethod
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
"""
Register a factory function for the workflow node execution repository.
Args:
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
"""
cls._register_factory("workflow_node_execution", factory_func)
@classmethod
def create_workflow_node_execution_repository(
cls, params: Optional[Mapping[str, Any]] = None
) -> WorkflowNodeExecutionRepository:
"""
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
Args:
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the WorkflowNodeExecutionRepository
Raises:
ValueError: If no factory function is registered for the workflow_node_execution repository type
"""
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))

View File

@@ -39,7 +39,7 @@ class SubCondition(BaseModel):
class SubVariableCondition(BaseModel):
logical_operator: Literal["and", "or"]
conditions: list[SubCondition] = Field(default=list)
conditions: list[SubCondition] = Field(default_factory=list)
class Condition(BaseModel):

View File

@@ -6,7 +6,6 @@ from typing import Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
@@ -52,10 +51,11 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
@@ -102,7 +102,7 @@
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManage(
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,

View File

@@ -69,7 +69,7 @@ from models.workflow import (
)
class WorkflowCycleManage:
class WorkflowCycleManager:
def __init__(
self,
*,

View File

@@ -39,6 +39,10 @@ def init_app(app: DifyApp):
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
log_tz = dify_config.LOG_TZ
@@ -74,3 +78,16 @@
def filter(self, record):
record.req_id = get_request_id() if flask.has_request_context() else ""
return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
return super().format(record)
def apply_request_id_formatter():
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
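For illustration, what the formatter above guards against: a record emitted outside a request context carries no req_id, and RequestIdFormatter backfills it with an empty string before delegating to the stock formatter (the format string here is illustrative, not Dify's default).

import logging

record = logging.LogRecord("demo", logging.INFO, __file__, 1, "hello", None, None)
formatter = RequestIdFormatter("[%(req_id)s] %(message)s")
print(formatter.format(record))  # -> "[] hello"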

View File

@@ -26,7 +26,7 @@ class Mail:
match mail_type:
case "resend":
import resend # type: ignore
import resend
api_key = dify_config.RESEND_API_KEY
if not api_key:

View File

@@ -6,6 +6,7 @@ import socket
import sys
from typing import Union
import flask
from celery.signals import worker_init # type: ignore
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
@@ -27,6 +28,8 @@ def on_user_loaded(_sender, user):
def init_app(app: DifyApp):
from opentelemetry.semconv.trace import SpanAttributes
def is_celery_worker():
return "celery" in sys.argv[0].lower()
@@ -37,7 +40,9 @@
def init_flask_instrumentor(app: DifyApp):
meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
_http_response_counter = meter.create_counter(
"http.server.response.count", description="Total number of HTTP responses by status code", unit="{response}"
"http.server.response.count",
description="Total number of HTTP responses by status code, method and target",
unit="{response}",
)
def response_hook(span: Span, status: str, response_headers: list):
@@ -50,7 +55,13 @@
status = status.split(" ")[0]
status_code = int(status)
status_class = f"{status_code // 100}xx"
_http_response_counter.add(1, {"status_code": status_code, "status_class": status_class})
attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class}
request = flask.request
if request and request.url_rule:
attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule)
if request and request.method:
attributes[SpanAttributes.HTTP_METHOD] = str(request.method)
_http_response_counter.add(1, attributes)
instrumentor = FlaskInstrumentor()
if dify_config.DEBUG:
@@ -103,8 +114,10 @@
pass
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
@@ -147,19 +160,32 @@
sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE)
provider = TracerProvider(resource=resource, sampler=sampler)
set_tracer_provider(provider)
exporter: Union[OTLPSpanExporter, ConsoleSpanExporter]
metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter]
exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter]
metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter]
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
exporter = OTLPSpanExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
)
metric_exporter = OTLPMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
)
if protocol == "grpc":
exporter = GRPCSpanExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
# Header field names must consist of lowercase letters, check RFC7540
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
)
metric_exporter = GRPCMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
)
else:
exporter = HTTPSpanExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
)
metric_exporter = HTTPMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
)
else:
# Fallback to console exporter
exporter = ConsoleSpanExporter()
metric_exporter = ConsoleMetricExporter()
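Condensed, the selection above keys off the new OTEL_EXPORTER_OTLP_PROTOCOL setting; the endpoints below use the conventional OTLP ports (4317 for gRPC, 4318 for HTTP) and are placeholders, not required values.

protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if protocol == "grpc":
    endpoint = dify_config.OTLP_BASE_ENDPOINT                 # e.g. http://localhost:4317, used as-is
else:
    endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces"  # e.g. http://localhost:4318/v1/traces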

View File

@@ -1,6 +1,7 @@
from typing import Any, Union
import redis
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel
@@ -51,6 +52,14 @@
connection_class: type[Union[Connection, SSLConnection]] = Connection
if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
if resp_protocol >= 3:
clientside_cache_config = CacheConfig()
else:
raise ValueError("Client side cache is only supported in RESP3")
else:
clientside_cache_config = None
redis_params: dict[str, Any] = {
"username": dify_config.REDIS_USERNAME,
@@ -59,6 +68,8 @@
"encoding": "utf-8",
"encoding_errors": "strict",
"decode_responses": False,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
}
if dify_config.REDIS_USE_SENTINEL:
@@ -82,14 +93,22 @@
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",")
]
# FIXME: mypy error here, try to figure out how to fix it
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore
redis_client.initialize(
RedisCluster(
startup_nodes=nodes,
password=dify_config.REDIS_CLUSTERS_PASSWORD,
protocol=resp_protocol,
cache_config=clientside_cache_config,
)
)
else:
redis_params.update(
{
"host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT,
"connection_class": connection_class,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
}
)
pool = redis.ConnectionPool(**redis_params)
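For reference, a standalone client configured the same way; host and port are placeholders, and, as the guard above enforces, redis-py only supports client-side caching on RESP3.

import redis
from redis.cache import CacheConfig

client = redis.Redis(
    host="localhost",
    port=6379,
    protocol=3,                  # RESP3, required for client-side caching
    cache_config=CacheConfig(),  # enable the local cache
    decode_responses=False,
)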

View File

@@ -1,18 +0,0 @@
"""
Extension for initializing repositories.
This extension registers repository implementations with the RepositoryFactory.
"""
from core.repositories.repository_registry import register_repositories
from dify_app import DifyApp
def init_app(_app: DifyApp) -> None:
"""
Initialize repository implementations.
Args:
_app: The Flask application instance (unused)
"""
register_repositories()

View File

@@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
).first()
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@@ -97,13 +101,17 @@
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
).first()
.first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
@@ -121,14 +129,18 @@
def sync_data_source(self, binding_id: str):
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
).first()
.first()
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
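This file, like several others in this commit, swaps Flask-SQLAlchemy's legacy Model.query accessor for explicit db.session.query(Model) calls, which keep working once models inherit from the plain Base class instead of db.Model (see the models/* changes below). A minimal sketch of the two equivalent read patterns, assuming the db session and DataSourceOauthBinding model imported in this module:

binding = (
    db.session.query(DataSourceOauthBinding)
    .filter(DataSourceOauthBinding.provider == "notion")
    .first()
)
# Legacy accessor being phased out:
# binding = DataSourceOauthBinding.query.filter(DataSourceOauthBinding.provider == "notion").first()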

View File

@@ -5,45 +5,61 @@ Revises: 33f5fac87f29
Create Date: 2024-10-10 05:16:14.764268
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op, context
# revision identifiers, used by Alembic.
revision = 'bbadea11becb'
down_revision = 'd8e744d88ed6'
revision = "bbadea11becb"
down_revision = "d8e744d88ed6"
branch_labels = None
depends_on = None
def upgrade():
def _has_name_or_size_column() -> bool:
# We cannot access the database in offline mode, so assume
# the "name" and "size" columns do not exist.
if context.is_offline_mode():
# Log a warning message to inform the user that the database schema cannot be inspected
# in offline mode, and the generated SQL may not accurately reflect the actual execution.
op.execute(
"-- Executing in offline mode, assuming the name and size columns do not exist.\n"
"-- The generated SQL may differ from what will actually be executed.\n"
"-- Please review the migration script carefully!"
)
return False
# Use SQLAlchemy inspector to get the columns of the 'tool_files' table
inspector = sa.inspect(conn)
columns = [col["name"] for col in inspector.get_columns("tool_files")]
# If 'name' or 'size' columns already exist, exit the upgrade function
if "name" in columns or "size" in columns:
return True
return False
# ### commands auto generated by Alembic - please adjust! ###
# Get the database connection
conn = op.get_bind()
# Use SQLAlchemy inspector to get the columns of the 'tool_files' table
inspector = sa.inspect(conn)
columns = [col['name'] for col in inspector.get_columns('tool_files')]
# If 'name' or 'size' columns already exist, exit the upgrade function
if 'name' in columns or 'size' in columns:
if _has_name_or_size_column():
return
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.add_column(sa.Column('name', sa.String(), nullable=True))
batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True))
with op.batch_alter_table("tool_files", schema=None) as batch_op:
batch_op.add_column(sa.Column("name", sa.String(), nullable=True))
batch_op.add_column(sa.Column("size", sa.Integer(), nullable=True))
op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL")
op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL")
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('name', existing_type=sa.String(), nullable=False)
batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False)
with op.batch_alter_table("tool_files", schema=None) as batch_op:
batch_op.alter_column("name", existing_type=sa.String(), nullable=False)
batch_op.alter_column("size", existing_type=sa.Integer(), nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.drop_column('size')
batch_op.drop_column('name')
with op.batch_alter_table("tool_files", schema=None) as batch_op:
batch_op.drop_column("size")
batch_op.drop_column("name")
# ### end Alembic commands ###
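The offline-mode pattern above generalizes: skip live schema inspection when Alembic is generating SQL (offline / --sql mode) and emit a warning comment instead. A hypothetical helper sketching the idea, not part of this migration:

from alembic import context, op
import sqlalchemy as sa

def column_exists(table: str, column: str) -> bool:
    # Hypothetical helper: in offline mode we cannot inspect the database,
    # so assume the column is absent and warn in the generated SQL.
    if context.is_offline_mode():
        op.execute(f"-- offline mode: assuming {table}.{column} does not exist")
        return False
    inspector = sa.inspect(op.get_bind())
    return column in [col["name"] for col in inspector.get_columns(table)]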

View File

@@ -5,28 +5,38 @@ Revises: e1944c35e15e
Create Date: 2024-12-23 11:54:15.344543
"""
from alembic import op
import models as models
import sqlalchemy as sa
from alembic import op, context
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'd7999dfa4aae'
down_revision = 'e1944c35e15e'
revision = "d7999dfa4aae"
down_revision = "e1944c35e15e"
branch_labels = None
depends_on = None
def upgrade():
# Check if column exists before attempting to remove it
conn = op.get_bind()
inspector = inspect(conn)
has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')]
def _has_retry_index_column() -> bool:
if context.is_offline_mode():
# Log a warning message to inform the user that the database schema cannot be inspected
# in offline mode, and the generated SQL may not accurately reflect the actual execution.
op.execute(
'-- Executing in offline mode: assuming the "retry_index" column does not exist.\n'
"-- The generated SQL may differ from what will actually be executed.\n"
"-- Please review the migration script carefully!"
)
return False
conn = op.get_bind()
inspector = inspect(conn)
return "retry_index" in [col["name"] for col in inspector.get_columns("workflow_node_executions")]
has_column = _has_retry_index_column()
if has_column:
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
batch_op.drop_column("retry_index")
def downgrade():

View File

@@ -0,0 +1,33 @@
"""add index for workflow_conversation_variables.conversation_id
Revision ID: d28f2004b072
Revises: 6a9f914f656c
Create Date: 2025-05-14 14:03:36.713828
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd28f2004b072'
down_revision = '6a9f914f656c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'), ['conversation_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'))
# ### end Alembic commands ###

View File

@@ -1,5 +1,6 @@
import enum
import json
from typing import cast
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
@@ -46,13 +47,12 @@ class Account(UserMixin, Base):
@property
def current_tenant(self):
        # FIXME: fix the type error later, because the type matters and may cause bugs
return self._current_tenant # type: ignore
@current_tenant.setter
def current_tenant(self, value: "Tenant"):
tenant = value
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
if ta:
tenant.current_role = ta.role
else:
@@ -64,25 +64,23 @@
def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None
@current_tenant_id.setter
def current_tenant_id(self, value: str):
try:
tenant_account_join = (
def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value)
.filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id)
.one_or_none()
)
),
)
if tenant_account_join:
tenant, ta = tenant_account_join
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
if not tenant_account_join:
return
tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant
@property
@@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
}
class Tenant(db.Model): # type: ignore[name-defined]
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@@ -220,7 +218,7 @@ class Tenant(db.Model):  # type: ignore[name-defined]
self.custom_config = json.dumps(value)
class TenantAccountJoin(db.Model): # type: ignore[name-defined]
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -239,7 +237,7 @@
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model): # type: ignore[name-defined]
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -256,7 +254,7 @@
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model): # type: ignore[name-defined]
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

View File

@@ -2,6 +2,7 @@ import enum
from sqlalchemy import func
from .base import Base
from .engine import db
from .types import StringUUID
@@ -13,7 +14,7 @@
APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(db.Model): # type: ignore[name-defined]
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

View File

@@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@@ -33,7 +34,7 @@
PARTIAL_TEAM = "partial_members"
class Dataset(db.Model): # type: ignore[name-defined]
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@@ -92,7 +93,8 @@ class Dataset(db.Model):  # type: ignore[name-defined]
@property
def latest_process_rule(self):
return (
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
)
@@ -137,7 +139,8 @@
@property
def word_count(self):
return (
Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count)))
.filter(Document.dataset_id == self.id)
.scalar()
)
@@ -255,7 +258,7 @@
return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model): # type: ignore[name-defined]
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@@ -295,7 +298,7 @@
return None
class Document(db.Model): # type: ignore[name-defined]
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
@@ -439,12 +442,13 @@
@property
def segment_count(self):
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return (
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
.filter(DocumentSegment.document_id == self.id)
.scalar()
)
@@ -635,7 +639,7 @@
)
class DocumentSegment(db.Model): # type: ignore[name-defined]
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@@ -786,7 +790,7 @@
return text
class ChildChunk(db.Model): # type: ignore[name-defined]
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@@ -829,7 +833,7 @@
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined]
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@@ -846,7 +850,7 @@
return db.session.get(App, self.app_id)
class DatasetQuery(db.Model): # type: ignore[name-defined]
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@@ -863,7 +867,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@@ -891,7 +895,7 @@
return dct
# get dataset
dataset = Dataset.query.filter_by(id=self.dataset_id).first()
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset:
return None
if self.data_source_type == "database":
@@ -908,7 +912,7 @@
return None
class Embedding(db.Model): # type: ignore[name-defined]
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@@ -932,7 +936,7 @@
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@@ -947,7 +951,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model): # type: ignore[name-defined]
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -967,7 +971,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model): # type: ignore[name-defined]
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@@ -979,7 +983,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model): # type: ignore[name-defined]
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@@ -996,7 +1000,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@@ -1049,7 +1053,7 @@
return dataset_bindings
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@@ -1070,7 +1074,7 @@
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@@ -1087,7 +1091,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined]
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@@ -1102,7 +1106,7 @@
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined]
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@@ -1121,7 +1125,7 @@
updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
from .workflow import WorkflowRunStatus
if TYPE_CHECKING:
from .workflow import Workflow
@@ -602,7 +602,7 @@ class InstalledApp(Base):
return tenant
class Conversation(db.Model): # type: ignore[name-defined]
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
for message in messages:
if message.workflow_run:
status_counts[message.workflow_run.status] += 1
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
return (
{
@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
}
class Message(db.Model): # type: ignore[name-defined]
class Message(Base):
__tablename__ = "messages"
__table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"),
@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
)
class MessageFeedback(db.Model): # type: ignore[name-defined]
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1237,8 +1237,23 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
return account
def to_dict(self):
return {
"id": str(self.id),
"app_id": str(self.app_id),
"conversation_id": str(self.conversation_id),
"message_id": str(self.message_id),
"rating": self.rating,
"content": self.content,
"from_source": self.from_source,
"from_end_user_id": str(self.from_end_user_id) if self.from_end_user_id else None,
"from_account_id": str(self.from_account_id) if self.from_account_id else None,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
class MessageFile(db.Model): # type: ignore[name-defined]
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@ -1279,7 +1294,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model): # type: ignore[name-defined]
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@ -1310,7 +1325,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1322,7 +1337,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False)
annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
@ -1348,7 +1363,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account
class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1364,26 +1379,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
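The conversation/message models in this file follow the same pattern as models/dataset.py above: each class now subclasses the shared declarative Base from models/base.py instead of flask_sqlalchemy's db.Model, which is why the `# type: ignore[name-defined]` comments disappear. A minimal sketch of the pattern, assuming Base is the project's SQLAlchemy 2.0 declarative base and ExampleRecord is a hypothetical model:

from sqlalchemy.orm import Mapped, mapped_column

from models.base import Base          # shared declarative base, as imported above
from models.types import StringUUID   # assumed import path, mirroring `from .types import StringUUID`


class ExampleRecord(Base):
    # Hypothetical model; previously it would have been declared as `db.Model  # type: ignore[name-defined]`.
    __tablename__ = "example_records"

    id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
    name: Mapped[str] = mapped_column(default="")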

View File

@ -2,8 +2,7 @@ from enum import Enum
from sqlalchemy import func
from models.base import Base
from .base import Base
from .engine import db
from .types import StringUUID

View File

@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

View File

@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Any, Optional, cast
from typing import Any, cast
import sqlalchemy as sa
from deprecated import deprecated
@ -304,8 +304,11 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
# who published this tool
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
@ -325,34 +328,3 @@ class DeprecatedPublishedAppTool(Base):
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
name: Mapped[str] = mapped_column(default="")
size: Mapped[int] = mapped_column(default=-1)
def __init__(
self,
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
file_key: str,
mimetype: str,
original_url: Optional[str] = None,
name: str,
size: int,
):
self.user_id = user_id
self.tenant_id = tenant_id
self.conversation_id = conversation_id
self.file_key = file_key
self.mimetype = mimetype
self.original_url = original_url
self.name = name
self.size = size

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
@ -18,11 +18,11 @@ from core.helper import encrypter
from core.variables import SecretVariable, Variable
from factories import variable_factory
from libs import helper
from models.base import Base
from models.enums import CreatedByRole
from .account import Account
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
if TYPE_CHECKING:
@ -736,8 +736,7 @@ class WorkflowAppLog(Base):
__tablename__ = "workflow_app_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id", "created_at"),
db.Index("workflow_app_log_workflow_run_idx", "workflow_run_id"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@ -769,17 +768,12 @@ class WorkflowAppLog(Base):
class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
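ConversationVariable now declares its secondary indexes at the column level (index=True) instead of as Index entries in __table_args__; SQLAlchemy emits an equivalent CREATE INDEX for either spelling, the visible difference being the auto-generated ix_* index names. A rough equivalence sketch under default naming conventions, using hypothetical table names:

from sqlalchemy import Column, Index, MetaData, String, Table

# Column-level shorthand ...
t1 = Table("example_a", MetaData(), Column("app_id", String, index=True))

# ... corresponds to an explicit Index object on the column:
t2 = Table("example_b", MetaData(), Column("app_id", String))
Index("ix_example_b_app_id", t2.c.app_id)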

View File

@ -14,7 +14,7 @@ dependencies = [
"chardet~=5.1.0",
"flask~=3.1.0",
"flask-compress~=1.17",
"flask-cors~=4.0.0",
"flask-cors~=5.0.0",
"flask-login~=0.6.3",
"flask-migrate~=4.0.7",
"flask-restful~=0.3.10",
@ -63,25 +63,24 @@ dependencies = [
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"pycryptodome==3.19.1",
"pydantic~=2.9.2",
"pydantic-extra-types~=2.9.0",
"pydantic-settings~=2.6.0",
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.9.1",
"pyjwt~=2.8.0",
"pypdfium2~=4.30.0",
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"pyyaml~=6.0.1",
"readabilipy==0.2.0",
"redis[hiredis]~=5.0.3",
"resend~=0.7.0",
"sentry-sdk[flask]~=1.44.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.0.0",
"resend~=2.9.0",
"sentry-sdk[flask]~=2.28.0",
"sqlalchemy~=2.0.29",
"starlette==0.41.0",
"tiktoken~=0.9.0",
"tokenizers~=0.15.0",
"transformers~=4.35.0",
"transformers~=4.51.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"weave~=0.51.34",
"weave~=0.51.0",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
]
@ -195,7 +194,7 @@ vdb = [
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.156",
"volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0",
"xinference-client~=1.2.2",
]

View File

@ -1,4 +1,5 @@
import datetime
import logging
import time
import click
@ -20,6 +21,8 @@ from models.model import (
from models.web import SavedMessage
from services.feature_service import FeatureService
_logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
def clean_messages():
@ -46,7 +49,14 @@ def clean_messages():
break
for message in messages:
plan_sandbox_clean_message_day = message.created_at
app = App.query.filter_by(id=message.app_id).first()
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
_logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:

View File

@ -2,7 +2,7 @@ import datetime
import time
import click
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
import app
@ -51,8 +51,9 @@ def clean_unused_datasets_task():
)
# Main query with join and filter
datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_sandbox_clean_day,
@ -60,9 +61,10 @@ def clean_unused_datasets_task():
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
)
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
if datasets.items is None or len(datasets.items) == 0:
@ -99,7 +101,7 @@ def clean_unused_datasets_task():
# update document
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e:
@ -135,8 +137,9 @@ def clean_unused_datasets_task():
)
# Main query with join and filter
datasets = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
stmt = (
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter(
Dataset.created_at < plan_pro_clean_day,
@ -144,8 +147,8 @@ def clean_unused_datasets_task():
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
)
.order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
)
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound:
break
@ -175,7 +178,7 @@ def clean_unused_datasets_task():
# update document
update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params)
db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit()
click.echo(
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")

View File

@ -19,7 +19,9 @@ def create_tidb_serverless_task():
while True:
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break
# create tidb serverless

View File

@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
@ -43,14 +45,16 @@ def mail_clean_document_notify_task():
if plan != "sandbox":
knowledge_details = []
# check tenant
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if not current_owner_join:
continue
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
if not account:
continue
@ -63,7 +67,7 @@ def mail_clean_document_notify_task():
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

View File

@ -5,6 +5,7 @@ import click
import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding
@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter()
try:
# check the number of idle tidb serverless
tidb_serverless_list = TidbAuthBinding.query.filter(
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
).all()
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0:
return
# update tidb serverless status

View File

@ -108,17 +108,20 @@ class AccountService:
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if not available_ta:
return None
account.current_tenant_id = available_ta.tenant_id
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
db.session.commit()
@ -297,9 +300,9 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
account_id=account.id, provider=provider
).first()
account_integrate: Optional[AccountIntegrate] = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
if account_integrate:
# If it exists, update the record
@ -612,7 +615,10 @@ class TenantService:
):
"""Check if user have a workspace or not"""
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
)
if available_ta:
@ -670,7 +676,7 @@ class TenantService:
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta:
tenant.role = ta.role
else:
@ -699,12 +705,12 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
TenantAccountJoin.query.filter(
db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()
@staticmethod
@ -791,7 +797,7 @@ class TenantService:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.")
@ -804,7 +810,7 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
@ -816,15 +822,23 @@ class TenantService:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
target_member_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
)
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join.role = "admin"
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if current_owner_join:
current_owner_join.role = "admin"
# Update the role of the target member
target_member_join.role = new_role
@ -841,7 +855,7 @@ class TenantService:
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404()
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
@ -966,7 +980,7 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta:
TenantService.create_tenant_member(tenant, account, role)
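Across AccountService and TenantService the legacy Model.query accessor is replaced with explicit db.session.query(...) calls (plus db.get_or_404 for the primary-key lookup in get_custom_config), consistent with the db.Model → Base migration elsewhere in this commit, since a plain declarative base does not carry Flask-SQLAlchemy's legacy query property. A minimal before/after sketch using TenantAccountJoin as in this file:

# Legacy accessor, relying on the flask_sqlalchemy model base:
# ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()

# Session-bound equivalent used after this change:
ta = (
    db.session.query(TenantAccountJoin)
    .filter_by(tenant_id=tenant.id, account_id=account.id)
    .first()
)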

View File

@ -4,7 +4,7 @@ from typing import cast
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -124,8 +124,9 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
if keyword:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
@ -133,14 +134,14 @@ class AppAnnotationService:
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
else:
annotations = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id)
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total
@classmethod
@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
annotation_hit_histories = (
AppAnnotationHitHistory.query.filter(
stmt = (
select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
.order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
)
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
)
return annotation_hit_histories.items, annotation_hit_histories.total
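Pagination moves to the same style: instead of calling .paginate() on a Model.query, the service builds a select() statement and hands it to Flask-SQLAlchemy's db.paginate(), which still returns a Pagination object exposing .items and .total. A condensed sketch of the pattern used above, assuming the same MessageAnnotation model and app_id/page/limit inputs:

from sqlalchemy import select

stmt = (
    select(MessageAnnotation)
    .filter(MessageAnnotation.app_id == app_id)
    .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
pagination = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
items, total = pagination.items, pagination.total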

View File

@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all()
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@ -129,7 +131,7 @@ class DatasetService:
else:
return [], 0
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
return datasets.items, datasets.total
@ -153,9 +155,10 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate(
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
)
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total
@staticmethod
@ -174,7 +177,7 @@ class DatasetService:
retrieval_model: Optional[RetrievalModel] = None,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == "high_quality":
@ -235,7 +238,7 @@ class DatasetService:
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@ -436,7 +439,7 @@ class DatasetService:
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data)
db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit()
if action:
@ -460,7 +463,7 @@ class DatasetService:
@staticmethod
def dataset_use_check(dataset_id) -> bool:
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count()
count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0:
return True
return False
@ -475,7 +478,9 @@ class DatasetService:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if (
not user_permission
and dataset.tenant_id != user.current_tenant_id
@ -499,23 +504,24 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = (
DatasetQuery.query.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at))
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
)
stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
return dataset_queries.items, dataset_queries.total
@staticmethod
def get_related_apps(dataset_id: str):
return (
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id)
db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@ -530,10 +536,14 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
).all()
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -873,7 +883,9 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
if document:
return document.position + 1
else:
@ -1010,13 +1022,17 @@ class DocumentService:
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
)
.first()
)
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@ -1054,12 +1070,16 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
.all()
)
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
@ -1067,14 +1087,18 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
@ -1206,12 +1230,16 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
).count()
documents_count = (
db.session.query(Document)
.filter(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
.count()
)
return documents_count
@staticmethod
@ -1278,14 +1306,18 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).first()
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
@ -1328,7 +1360,7 @@ class DocumentService:
db.session.commit()
# update document segment
update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params)
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@ -1918,7 +1950,8 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id)
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@ -2157,20 +2190,28 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
query = ChildChunk.query.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
).order_by(ChildChunk.position.asc())
query = (
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
document_id=document_id,
segment_id=segment_id,
)
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first()
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@classmethod
@ -2184,7 +2225,7 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
@ -2194,9 +2235,8 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total
@ -2236,9 +2276,11 @@ class SegmentService:
raise ValueError(ex.description)
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id
).first()
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
raise NotFound("Segment not found.")
@ -2251,9 +2293,11 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID."""
result = DocumentSegment.query.filter(
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id
).first()
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None

View File

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService:
@staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]:
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by(
ExternalKnowledgeApis.created_at.desc()
def get_external_knowledge_apis(
page, per_page, tenant_id, search=None
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total
@ -92,18 +99,18 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
return external_knowledge_api
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@ -120,9 +127,9 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -131,25 +138,29 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count()
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0:
return True, count
return False, 0
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_api_id, tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings)
@ -212,11 +223,13 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -254,15 +267,17 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
).first()
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by(
id=external_knowledge_binding.external_knowledge_api_id
).first()
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api:
raise ValueError("external api template not found")

View File

@ -69,6 +69,7 @@ class HitTestingService:
query: str,
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
if dataset.provider != "external":
return {
@ -82,6 +83,7 @@ class HitTestingService:
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
metadata_filtering_conditions=metadata_filtering_conditions,
)
end = time.perf_counter()

View File

@ -177,6 +177,21 @@ class MessageService:
return feedback
@classmethod
def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
"""Get all feedbacks of an app"""
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
.all()
)
return [record.to_dict() for record in feedbacks]
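The new get_all_messages_feedbacks pairs the MessageFeedback.to_dict() serializer added earlier in this commit with plain offset/limit paging, which is enough here because the endpoint returns only the requested page and no total count. A small usage sketch with hypothetical page/limit values:

# Hypothetical caller: first page, 20 feedback records.
feedback_dicts = MessageService.get_all_messages_feedbacks(app_model, page=1, limit=20)
for item in feedback_dicts:
    print(item["rating"], item["content"], item["created_at"])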
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = (

View File

@ -20,9 +20,11 @@ class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
@ -42,16 +44,18 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@ -60,7 +64,9 @@ class MetadataService:
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -82,13 +88,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document)
db.session.commit()
# deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
@ -230,9 +238,9 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by(
metadata_id=item.get("id"), dataset_id=dataset.id
).count(),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@ -1,3 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
@ -12,17 +13,27 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
_logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents = []
documents: list[Document] = []
document: Document | None = None
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first()
document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
segment.document_id,
segment.id,
)
continue
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)

View File

@ -2,7 +2,7 @@ import threading
from typing import Optional
import contexts
from core.workflow.repository import RepositoryFactory
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -129,12 +129,8 @@ class WorkflowRunService:
return []
# Use the repository to get the node executions
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind(),
}
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
)
# Use the repository to get the node executions with ordering

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@ -21,7 +22,6 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repository import RepositoryFactory
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@ -285,12 +285,8 @@ class WorkflowService:
workflow_node_execution.workflow_id = draft_workflow.id
# Use the repository to save the workflow node execution
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind(),
}
repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
)
repository.save(workflow_node_execution)
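Both workflow services now instantiate SQLAlchemyWorkflowNodeExecutionRepository directly, scoped to the tenant and app and backed by db.engine, instead of going through RepositoryFactory with a params dict. A condensed sketch of the call shape shared by the two sites above:

from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db

repository = SQLAlchemyWorkflowNodeExecutionRepository(
    session_factory=db.engine,  # the engine is passed as the session factory here
    tenant_id=app_model.tenant_id,
    app_id=app_model.id,
)
repository.save(workflow_node_execution)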

Some files were not shown because too many files have changed in this diff.