mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/end-user-oauth
This commit is contained in:
commit
b60ba0b192
|
|
@ -12,7 +12,7 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
|
|
@ -38,10 +38,10 @@ def admin_required(view: Callable[P, R]):
|
|||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@api.doc("insert_explore_app")
|
||||
@api.doc(description="Insert or update an app in the explore list")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
|
|
@ -55,9 +55,9 @@ class InsertExploreAppListApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully")
|
||||
@api.response(201, "App inserted successfully")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
|
|
@ -131,10 +131,10 @@ class InsertExploreAppListApi(Resource):
|
|||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@api.doc("delete_explore_app")
|
||||
@api.doc(description="Remove an app from the explore list")
|
||||
@api.doc(params={"app_id": "Application ID to remove"})
|
||||
@api.response(204, "App removed successfully")
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from libs.login import current_account_with_tenant, login_required
|
|||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
|
|
@ -24,6 +24,12 @@ api_key_fields = {
|
|||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
|
|
@ -52,7 +58,7 @@ class BaseApiKeyListResource(Resource):
|
|||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list)
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
|
|
@ -66,7 +72,7 @@ class BaseApiKeyListResource(Resource):
|
|||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_fields)
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
|
|
@ -133,20 +139,20 @@ class BaseApiKeyResource(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_app_api_keys")
|
||||
@api.doc(description="Get all API keys for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_app_api_key")
|
||||
@api.doc(description="Create a new API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
|
|
@ -158,10 +164,10 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class AppApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_app_api_key")
|
||||
@api.doc(description="Delete an API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
@console_ns.doc("delete_app_api_key")
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
|
@ -173,20 +179,20 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
|
||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_dataset_api_keys")
|
||||
@api.doc(description="Get all API keys for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_dataset_api_key")
|
||||
@api.doc(description="Create a new API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
|
|
@ -198,10 +204,10 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_dataset_api_key")
|
||||
@api.doc(description="Delete an API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
|
@ -16,13 +16,13 @@ parser = (
|
|||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
|
|
@ -17,12 +17,14 @@ parser = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from flask import request
|
|||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
|
|
@ -15,6 +15,7 @@ from extensions.ext_redis import redis_client
|
|||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
|
|
@ -23,11 +24,11 @@ from services.annotation_service import AppAnnotationService
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@api.doc("annotation_reply_action")
|
||||
@api.doc(description="Enable or disable annotation reply for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
|
|
@ -36,8 +37,8 @@ class AnnotationReplyActionApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Action completed successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -61,11 +62,11 @@ class AnnotationReplyActionApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@api.doc("get_annotation_setting")
|
||||
@api.doc(description="Get annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotation settings retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_annotation_setting")
|
||||
@console_ns.doc(description="Get annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Annotation settings retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -78,11 +79,11 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@api.doc("update_annotation_setting")
|
||||
@api.doc(description="Update annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
|
|
@ -91,8 +92,8 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Settings updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -110,11 +111,11 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@api.doc("get_annotation_reply_action_status")
|
||||
@api.doc(description="Get status of annotation reply action job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_annotation_reply_action_status")
|
||||
@console_ns.doc(description="Get status of annotation reply action job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -138,17 +139,17 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class AnnotationApi(Resource):
|
||||
@api.doc("list_annotations")
|
||||
@api.doc(description="Get annotations for an app with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@api.response(200, "Annotations retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -169,11 +170,11 @@ class AnnotationApi(Resource):
|
|||
}
|
||||
return response, 200
|
||||
|
||||
@api.doc("create_annotation")
|
||||
@api.doc(description="Create a new annotation for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
|
|
@ -184,8 +185,8 @@ class AnnotationApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -235,11 +236,15 @@ class AnnotationApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
class AnnotationExportApi(Resource):
|
||||
@api.doc("export_annotations")
|
||||
@api.doc(description="Export all annotations for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("export_annotations")
|
||||
@console_ns.doc(description="Export all annotations for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -260,13 +265,13 @@ parser = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
@api.doc(description="Update or delete an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.expect(parser)
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -293,12 +298,12 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@api.doc("batch_import_annotations")
|
||||
@api.doc(description="Batch import annotations from CSV file")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Batch import started successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "No file uploaded or too many files")
|
||||
@console_ns.doc("batch_import_annotations")
|
||||
@console_ns.doc(description="Batch import annotations from CSV file")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Batch import started successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "No file uploaded or too many files")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -323,11 +328,11 @@ class AnnotationBatchImportApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@api.doc("get_batch_import_status")
|
||||
@api.doc(description="Get status of batch import job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_batch_import_status")
|
||||
@console_ns.doc(description="Get status of batch import job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -350,18 +355,27 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@api.doc("list_annotation_hit_histories")
|
||||
@api.doc(description="Get hit histories for an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_annotation_hit_histories")
|
||||
@console_ns.doc(description="Get hit histories for an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
)
|
||||
@api.response(
|
||||
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, abort
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -18,7 +18,15 @@ from controllers.console.wraps import (
|
|||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from fields.app_fields import (
|
||||
deleted_tool_fields,
|
||||
model_config_fields,
|
||||
model_config_partial_fields,
|
||||
site_fields,
|
||||
tag_fields,
|
||||
)
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
|
|
@ -29,13 +37,118 @@ from services.feature_service import FeatureService
|
|||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
tag_model = console_ns.model("Tag", tag_fields)
|
||||
|
||||
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
|
||||
|
||||
model_config_model = console_ns.model("ModelConfig", model_config_fields)
|
||||
|
||||
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
|
||||
|
||||
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
|
||||
|
||||
site_model = console_ns.model("Site", site_fields)
|
||||
|
||||
app_partial_model = console_ns.model(
|
||||
"AppPartial",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"max_active_requests": fields.Raw(),
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_model = console_ns.model(
|
||||
"AppDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_with_site_model = console_ns.model(
|
||||
"AppDetailWithSite",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"site": fields.Nested(site_model),
|
||||
},
|
||||
)
|
||||
|
||||
app_pagination_model = console_ns.model(
|
||||
"AppPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@api.doc(description="Get list of applications with pagination and filtering")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
|
|
@ -50,7 +163,7 @@ class AppListApi(Resource):
|
|||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@api.response(200, "Success", app_pagination_fields)
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -137,12 +250,12 @@ class AppListApi(Resource):
|
|||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
return marshal(app_pagination, app_pagination_model), 200
|
||||
|
||||
@api.doc("create_app")
|
||||
@api.doc(description="Create a new application")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
|
|
@ -154,13 +267,13 @@ class AppListApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App created successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
|
|
@ -188,16 +301,16 @@ class AppListApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>")
|
||||
class AppApi(Resource):
|
||||
@api.doc("get_app_detail")
|
||||
@api.doc(description="Get application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Success", app_detail_fields_with_site)
|
||||
@console_ns.doc("get_app_detail")
|
||||
@console_ns.doc(description="Get application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Success", app_detail_with_site_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
|
|
@ -210,11 +323,11 @@ class AppApi(Resource):
|
|||
|
||||
return app_model
|
||||
|
||||
@api.doc("update_app")
|
||||
@api.doc(description="Update application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
|
|
@ -227,15 +340,15 @@ class AppApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
parser = (
|
||||
|
|
@ -265,11 +378,11 @@ class AppApi(Resource):
|
|||
|
||||
return app_model
|
||||
|
||||
@api.doc("delete_app")
|
||||
@api.doc(description="Delete application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(204, "App deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("delete_app")
|
||||
@console_ns.doc(description="Delete application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(204, "App deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -285,11 +398,11 @@ class AppApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
class AppCopyApi(Resource):
|
||||
@api.doc("copy_app")
|
||||
@api.doc(description="Create a copy of an existing application")
|
||||
@api.doc(params={"app_id": "Application ID to copy"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
|
|
@ -300,14 +413,14 @@ class AppCopyApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App copied successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
|
|
@ -346,20 +459,20 @@ class AppCopyApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export")
|
||||
class AppExportApi(Resource):
|
||||
@api.doc("export_app")
|
||||
@api.doc(description="Export application configuration as DSL")
|
||||
@api.doc(params={"app_id": "Application ID to export"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -387,16 +500,16 @@ parser = reqparse.RequestParser().add_argument("name", type=str, required=True,
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Name availability checked")
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args = parser.parse_args()
|
||||
|
|
@ -409,11 +522,11 @@ class AppNameApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/icon")
|
||||
class AppIconApi(Resource):
|
||||
@api.doc("update_app_icon")
|
||||
@api.doc(description="Update application icon")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
|
|
@ -422,13 +535,13 @@ class AppIconApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Icon updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = (
|
||||
|
|
@ -446,21 +559,21 @@ class AppIconApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site-enable")
|
||||
class AppSiteStatus(Resource):
|
||||
@api.doc("update_app_site_status")
|
||||
@api.doc(description="Enable or disable app site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||
|
|
@ -474,22 +587,22 @@ class AppSiteStatus(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/api-enable")
|
||||
class AppApiStatus(Resource):
|
||||
@api.doc("update_app_api_status")
|
||||
@api.doc(description="Enable or disable app API")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "API status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
|
@ -502,10 +615,10 @@ class AppApiStatus(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace")
|
||||
class AppTraceApi(Resource):
|
||||
@api.doc("get_app_trace")
|
||||
@api.doc(description="Get app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Trace configuration retrieved successfully")
|
||||
@console_ns.doc("get_app_trace")
|
||||
@console_ns.doc(description="Get app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Trace configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -515,11 +628,11 @@ class AppTraceApi(Resource):
|
|||
|
||||
return app_trace_config
|
||||
|
||||
@api.doc("update_app_trace")
|
||||
@api.doc(description="Update app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
|
|
@ -527,8 +640,8 @@ class AppTraceApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trace configuration updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -10,7 +9,11 @@ from controllers.console.wraps import (
|
|||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from fields.app_fields import (
|
||||
app_import_check_dependencies_fields,
|
||||
app_import_fields,
|
||||
leaked_dependency_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
|
|
@ -19,6 +22,19 @@ from services.feature_service import FeatureService
|
|||
|
||||
from .. import console_ns
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
|
||||
|
||||
app_import_model = console_ns.model("AppImport", app_import_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
|
||||
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
|
||||
app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
|
|
@ -35,11 +51,11 @@ parser = (
|
|||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@api.expect(parser)
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
|
|
@ -82,7 +98,7 @@ class AppImportConfirmApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@marshal_with(app_import_model)
|
||||
@edit_permission_required
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
|
|
@ -108,7 +124,7 @@ class AppImportCheckDependenciesApi(Resource):
|
|||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with Session(db.engine) as session:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
|
|||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
|
|
@ -36,16 +36,16 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@api.doc("chat_message_audio_transcript")
|
||||
@api.doc(description="Transcript audio to text for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("chat_message_audio_transcript")
|
||||
@console_ns.doc(description="Transcript audio to text for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
)
|
||||
@api.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@api.response(413, "Audio file too large")
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -89,11 +89,11 @@ class ChatMessageAudioApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
|
||||
class ChatMessageTextApi(Resource):
|
||||
@api.doc("chat_message_text_to_speech")
|
||||
@api.doc(description="Convert text to speech for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
|
|
@ -103,8 +103,8 @@ class ChatMessageTextApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Text to speech conversion successful")
|
||||
@api.response(400, "Bad request - Invalid parameters")
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -156,12 +156,16 @@ class ChatMessageTextApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
class TextModesApi(Resource):
|
||||
@api.doc("get_text_to_speech_voices")
|
||||
@api.doc(description="Get available TTS voices for a specific language")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
|
||||
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
|
||||
@api.response(400, "Invalid language parameter")
|
||||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
@console_ns.response(400, "Invalid language parameter")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
|
|||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
|
|
@ -40,11 +40,11 @@ logger = logging.getLogger(__name__)
|
|||
# define completion message api for user
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
|
||||
class CompletionMessageApi(Resource):
|
||||
@api.doc("create_completion_message")
|
||||
@api.doc(description="Generate completion message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
|
|
@ -56,9 +56,9 @@ class CompletionMessageApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Completion generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -110,10 +110,10 @@ class CompletionMessageApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
class CompletionMessageStopApi(Resource):
|
||||
@api.doc("stop_completion_message")
|
||||
@api.doc(description="Stop a running completion message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@console_ns.doc("stop_completion_message")
|
||||
@console_ns.doc(description="Stop a running completion message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -128,11 +128,11 @@ class CompletionMessageStopApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageApi(Resource):
|
||||
@api.doc("create_chat_message")
|
||||
@api.doc(description="Generate chat message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
|
|
@ -146,9 +146,9 @@ class ChatMessageApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Chat message generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App or conversation not found")
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -209,10 +209,10 @@ class ChatMessageApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
class ChatMessageStopApi(Resource):
|
||||
@api.doc("stop_chat_message")
|
||||
@api.doc(description="Stop a running chat message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@console_ns.doc("stop_chat_message")
|
||||
@console_ns.doc(description="Stop a running chat message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,38 +1,290 @@
|
|||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
conversation_detail_fields,
|
||||
conversation_message_detail_fields,
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.helper import DatetimeString, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
feedback_stat_model = console_ns.model(
|
||||
"FeedbackStat",
|
||||
{
|
||||
"like": fields.Integer,
|
||||
"dislike": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
status_count_model = console_ns.model(
|
||||
"StatusCount",
|
||||
{
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
simple_model_config_model = console_ns.model(
|
||||
"SimpleModelConfig",
|
||||
{
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
model_config_model = console_ns.model(
|
||||
"ModelConfig",
|
||||
{
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
{
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Conversation models
|
||||
conversation_fields_model = console_ns.model(
|
||||
"Conversation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_pagination_model = console_ns.model(
|
||||
"ConversationPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_message_detail_model = console_ns.model(
|
||||
"ConversationMessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message": fields.Nested(message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_model = console_ns.model(
|
||||
"ConversationWithSummary",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"status_count": fields.Nested(status_count_model),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_pagination_model = console_ns.model(
|
||||
"ConversationWithSummaryPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_detail_model = console_ns.model(
|
||||
"ConversationDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
|
||||
class CompletionConversationApi(Resource):
|
||||
@api.doc("list_completion_conversations")
|
||||
@api.doc(description="Get completion conversations with pagination and filtering")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_completion_conversations")
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
|
|
@ -47,13 +299,13 @@ class CompletionConversationApi(Resource):
|
|||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", conversation_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
@marshal_with(conversation_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -122,29 +374,29 @@ class CompletionConversationApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
@api.doc("get_completion_conversation")
|
||||
@api.doc(description="Get completion conversation details with messages")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_message_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("get_completion_conversation")
|
||||
@console_ns.doc(description="Get completion conversation details with messages")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_message_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
@marshal_with(conversation_message_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_completion_conversation")
|
||||
@api.doc(description="Delete a completion conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@console_ns.doc(description="Delete a completion conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -164,11 +416,11 @@ class CompletionConversationDetailApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
class ChatConversationApi(Resource):
|
||||
@api.doc("list_chat_conversations")
|
||||
@api.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_chat_conversations")
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
|
|
@ -192,13 +444,13 @@ class ChatConversationApi(Resource):
|
|||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@api.response(200, "Success", conversation_with_summary_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
@marshal_with(conversation_with_summary_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -322,29 +574,29 @@ class ChatConversationApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
class ChatConversationDetailApi(Resource):
|
||||
@api.doc("get_chat_conversation")
|
||||
@api.doc(description="Get chat conversation details")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("get_chat_conversation")
|
||||
@console_ns.doc(description="Get chat conversation details")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_fields)
|
||||
@marshal_with(conversation_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_chat_conversation")
|
||||
@api.doc(description="Delete a chat conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@console_ns.doc(description="Delete a chat conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
|
|
|
|||
|
|
@ -1,33 +1,49 @@
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -24,10 +24,10 @@ from services.workflow_service import WorkflowService
|
|||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@api.doc("generate_rule_config")
|
||||
@api.doc(description="Generate rule configuration using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
|
|
@ -36,9 +36,9 @@ class RuleGenerateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Rule configuration generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -73,10 +73,10 @@ class RuleGenerateApi(Resource):
|
|||
|
||||
@console_ns.route("/rule-code-generate")
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@api.doc("generate_rule_code")
|
||||
@api.doc(description="Generate code rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
|
|
@ -88,9 +88,9 @@ class RuleCodeGenerateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Code rules generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -126,10 +126,10 @@ class RuleCodeGenerateApi(Resource):
|
|||
|
||||
@console_ns.route("/rule-structured-output-generate")
|
||||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@api.doc("generate_structured_output")
|
||||
@api.doc(description="Generate structured output rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
|
|
@ -137,9 +137,9 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Structured output generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -172,10 +172,10 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||
|
||||
@console_ns.route("/instruction-generate")
|
||||
class InstructionGenerateApi(Resource):
|
||||
@api.doc("generate_instruction")
|
||||
@api.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
|
|
@ -188,9 +188,9 @@ class InstructionGenerateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Instruction generated successfully")
|
||||
@api.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -283,10 +283,10 @@ class InstructionGenerateApi(Resource):
|
|||
|
||||
@console_ns.route("/instruction-generate/template")
|
||||
class InstructionGenerationTemplateApi(Resource):
|
||||
@api.doc("get_instruction_template")
|
||||
@api.doc(description="Get instruction generation template")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
|
|
@ -294,8 +294,8 @@ class InstructionGenerationTemplateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Template retrieved successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from enum import StrEnum
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -12,6 +12,9 @@ from fields.app_fields import app_server_fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
|
|
@ -20,24 +23,24 @@ class AppMCPServerStatus(StrEnum):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@api.doc("get_app_mcp_server")
|
||||
@api.doc(description="Get MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@api.doc("create_app_mcp_server")
|
||||
@api.doc(description="Create MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
|
|
@ -45,13 +48,13 @@ class AppMCPServerController(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
|
@ -79,11 +82,11 @@ class AppMCPServerController(Resource):
|
|||
db.session.commit()
|
||||
return server
|
||||
|
||||
@api.doc("update_app_mcp_server")
|
||||
@api.doc(description="Update MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
|
|
@ -93,14 +96,14 @@ class AppMCPServerController(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
parser = (
|
||||
|
|
@ -134,16 +137,16 @@ class AppMCPServerController(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@api.doc("refresh_app_mcp_server")
|
||||
@api.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@api.doc(params={"server_id": "Server ID"})
|
||||
@api.response(200, "MCP server refreshed successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx.inputs import int_range
|
|||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
|
|
@ -23,8 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
|
|
@ -34,31 +34,142 @@ from services.message_service import MessageService
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message infinite scroll pagination model
|
||||
message_infinite_scroll_pagination_model = console_ns.model(
|
||||
"MessageInfiniteScrollPagination",
|
||||
{
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_model)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
}
|
||||
|
||||
@api.doc("list_chat_messages")
|
||||
@api.doc(description="Get chat messages for a conversation with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
parser = (
|
||||
|
|
@ -132,11 +243,11 @@ class ChatMessageListApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@api.doc("create_message_feedback")
|
||||
@api.doc(description="Create or update message feedback (like/dislike)")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
|
|
@ -144,9 +255,9 @@ class MessageFeedbackApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Feedback updated successfully")
|
||||
@api.response(404, "Message not found")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -194,13 +305,13 @@ class MessageFeedbackApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
@api.doc(description="Get count of message annotations for the app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("get_annotation_count")
|
||||
@console_ns.doc(description="Get count of message annotations for the app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
|
|
@ -214,15 +325,17 @@ class MessageAnnotationCountApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(Resource):
|
||||
@api.doc("get_message_suggested_questions")
|
||||
@api.doc(description="Get suggested questions for a message")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("get_message_suggested_questions")
|
||||
@console_ns.doc(description="Get suggested questions for a message")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
)
|
||||
@api.response(404, "Message or conversation not found")
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -258,16 +371,16 @@ class MessageSuggestedQuestionApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@api.doc("get_message")
|
||||
@api.doc(description="Get message details by ID")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||
@api.response(404, "Message not found")
|
||||
@console_ns.doc("get_message")
|
||||
@console_ns.doc(description="Get message details by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
|
||||
@console_ns.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_detail_fields)
|
||||
@marshal_with(message_detail_model)
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import cast
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
|
|
@ -20,11 +20,11 @@ from services.app_model_config_service import AppModelConfigService
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/model-config")
|
||||
class ModelConfigResource(Resource):
|
||||
@api.doc("update_app_model_config")
|
||||
@api.doc(description="Update application model configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_model_config")
|
||||
@console_ns.doc(description="Update application model configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ModelConfigRequest",
|
||||
{
|
||||
"provider": fields.String(description="Model provider"),
|
||||
|
|
@ -42,9 +42,9 @@ class ModelConfigResource(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Model configuration updated successfully")
|
||||
@api.response(400, "Invalid configuration")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "Model configuration updated successfully")
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
|
|
@ -14,18 +14,18 @@ class TraceAppConfigApi(Resource):
|
|||
Manage trace app configurations
|
||||
"""
|
||||
|
||||
@api.doc("get_trace_app_config")
|
||||
@api.doc(description="Get tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -41,11 +41,11 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("create_trace_app_config")
|
||||
@api.doc(description="Create a new tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
|
|
@ -53,10 +53,10 @@ class TraceAppConfigApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters or configuration already exists")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -81,11 +81,11 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("update_trace_app_config")
|
||||
@api.doc(description="Update an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
|
|
@ -93,8 +93,8 @@ class TraceAppConfigApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -117,16 +117,16 @@ class TraceAppConfigApi(Resource):
|
|||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("delete_trace_app_config")
|
||||
@api.doc(description="Delete an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(204, "Tracing configuration deleted successfully")
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask_restx import Resource, fields, marshal_with, reqparse
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -16,6 +16,9 @@ from libs.datetime_utils import naive_utc_now
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = (
|
||||
|
|
@ -48,11 +51,11 @@ def parse_app_site_args():
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@api.doc("update_app_site")
|
||||
@api.doc(description="Update application site configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
|
|
@ -76,15 +79,15 @@ class AppSite(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site configuration updated successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
|
@ -123,18 +126,18 @@ class AppSite(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
class AppSiteAccessTokenReset(Resource):
|
||||
@api.doc("reset_app_site_access_token")
|
||||
@api.doc(description="Reset access token for application site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Access token reset successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@api.response(404, "App or site not found")
|
||||
@console_ns.doc("reset_app_site_access_token")
|
||||
@console_ns.doc(description="Reset access token for application site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Access token reset successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@console_ns.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sqlalchemy as sa
|
|||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -17,15 +17,15 @@ from models import AppMode
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@api.doc("get_daily_message_statistics")
|
||||
@api.doc(description="Get daily message statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily message count data")),
|
||||
|
|
@ -90,11 +90,11 @@ parser = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily conversation count data")),
|
||||
|
|
@ -146,11 +146,11 @@ WHERE
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily terminal count data")),
|
||||
|
|
@ -203,11 +203,11 @@ WHERE
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily token cost data")),
|
||||
|
|
@ -263,11 +263,11 @@ WHERE
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average session interaction data")),
|
||||
|
|
@ -339,11 +339,11 @@ ORDER BY
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="User satisfaction rate data")),
|
||||
|
|
@ -405,11 +405,11 @@ WHERE
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average response time data")),
|
||||
|
|
@ -462,11 +462,11 @@ WHERE
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Tokens per second data")),
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
|
@ -32,6 +32,7 @@ from core.workflow.enums import NodeType
|
|||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
|
|
@ -49,6 +50,56 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
|
|||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
from fields.workflow_fields import pipeline_variable_fields, serialize_value_type
|
||||
|
||||
conversation_variable_model = console_ns.model(
|
||||
"ConversationVariable",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"value_type": fields.String(attribute=serialize_value_type),
|
||||
"value": fields.Raw,
|
||||
"description": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
pipeline_variable_model = console_ns.model("PipelineVariable", pipeline_variable_fields)
|
||||
|
||||
# Workflow model with nested dependencies
|
||||
workflow_fields_copy = workflow_fields.copy()
|
||||
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
|
||||
workflow_fields_copy["updated_by"] = fields.Nested(
|
||||
simple_account_model, attribute="updated_by_account", allow_null=True
|
||||
)
|
||||
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
|
||||
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
|
||||
workflow_model = console_ns.model("Workflow", workflow_fields_copy)
|
||||
|
||||
# Workflow pagination model
|
||||
workflow_pagination_fields_copy = workflow_pagination_fields.copy()
|
||||
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
|
||||
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
|
||||
|
||||
# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
|
||||
# Otherwise register it here
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
|
||||
try:
|
||||
simple_end_user_model = console_ns.models.get("SimpleEndUser")
|
||||
except (KeyError, AttributeError):
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
try:
|
||||
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
|
||||
except (KeyError, AttributeError):
|
||||
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
# at the controller level rather than in the workflow logic. This would improve separation
|
||||
|
|
@ -70,16 +121,16 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft")
|
||||
class DraftWorkflowApi(Resource):
|
||||
@api.doc("get_draft_workflow")
|
||||
@api.doc(description="Get draft workflow for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Draft workflow retrieved successfully", workflow_fields)
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@console_ns.doc("get_draft_workflow")
|
||||
@console_ns.doc(description="Get draft workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
|
@ -99,10 +150,10 @@ class DraftWorkflowApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@api.doc("sync_draft_workflow")
|
||||
@api.doc(description="Sync draft workflow configuration")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncDraftWorkflowRequest",
|
||||
{
|
||||
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
|
||||
|
|
@ -113,10 +164,10 @@ class DraftWorkflowApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"SyncDraftWorkflowResponse",
|
||||
{
|
||||
"result": fields.String,
|
||||
|
|
@ -125,8 +176,8 @@ class DraftWorkflowApi(Resource):
|
|||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid workflow configuration")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(400, "Invalid workflow configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
|
|
@ -198,11 +249,11 @@ class DraftWorkflowApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
|
||||
class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@api.doc("run_advanced_chat_draft_workflow")
|
||||
@api.doc(description="Run draft workflow for advanced chat application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_advanced_chat_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow for advanced chat application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AdvancedChatWorkflowRunRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
|
|
@ -212,9 +263,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow run started successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Workflow run started successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -262,11 +313,11 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@api.doc("run_advanced_chat_draft_iteration_node")
|
||||
@api.doc(description="Run draft workflow iteration node for advanced chat")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_advanced_chat_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"IterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
|
|
@ -274,9 +325,9 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Iteration node run started successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(404, "Node not found")
|
||||
@console_ns.response(200, "Iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -309,11 +360,11 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@api.doc("run_workflow_draft_iteration_node")
|
||||
@api.doc(description="Run draft workflow iteration node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_workflow_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowIterationNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
|
|
@ -321,9 +372,9 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow iteration node run started successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(404, "Node not found")
|
||||
@console_ns.response(200, "Workflow iteration node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -356,11 +407,11 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@api.doc("run_advanced_chat_draft_loop_node")
|
||||
@api.doc(description="Run draft workflow loop node for advanced chat")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_advanced_chat_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"LoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
|
|
@ -368,9 +419,9 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Loop node run started successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(404, "Node not found")
|
||||
@console_ns.response(200, "Loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -403,11 +454,11 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
@api.doc("run_workflow_draft_loop_node")
|
||||
@api.doc(description="Run draft workflow loop node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_workflow_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WorkflowLoopNodeRunRequest",
|
||||
{
|
||||
"task_id": fields.String(required=True, description="Task ID"),
|
||||
|
|
@ -415,9 +466,9 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow loop node run started successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(404, "Node not found")
|
||||
@console_ns.response(200, "Workflow loop node run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -450,11 +501,11 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@api.doc("run_draft_workflow")
|
||||
@api.doc(description="Run draft workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_draft_workflow")
|
||||
@console_ns.doc(description="Run draft workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
|
|
@ -462,8 +513,8 @@ class DraftWorkflowRunApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Draft workflow run started successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Draft workflow run started successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -501,12 +552,12 @@ class DraftWorkflowRunApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@api.doc("stop_workflow_task")
|
||||
@api.doc(description="Stop running workflow task")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@api.response(404, "Task not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.doc("stop_workflow_task")
|
||||
@console_ns.doc(description="Stop running workflow task")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Task not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -528,25 +579,25 @@ class WorkflowTaskStopApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class DraftWorkflowNodeRunApi(Resource):
|
||||
@api.doc("run_draft_workflow_node")
|
||||
@api.doc(description="Run draft workflow node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("run_draft_workflow_node")
|
||||
@console_ns.doc(description="Run draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowNodeRunRequest",
|
||||
{
|
||||
"inputs": fields.Raw(description="Input variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Node run started successfully", workflow_run_node_execution_fields)
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(404, "Node not found")
|
||||
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
|
|
@ -595,16 +646,16 @@ parser_publish = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@api.doc("get_published_workflow")
|
||||
@api.doc(description="Get published workflow for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Published workflow retrieved successfully", workflow_fields)
|
||||
@api.response(404, "Published workflow not found")
|
||||
@console_ns.doc("get_published_workflow")
|
||||
@console_ns.doc(description="Get published workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Published workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
|
@ -617,7 +668,7 @@ class PublishedWorkflowApi(Resource):
|
|||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@api.expect(parser_publish)
|
||||
@console_ns.expect(parser_publish)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -666,10 +717,10 @@ class PublishedWorkflowApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@api.doc("get_default_block_configs")
|
||||
@api.doc(description="Get default block configurations for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Default block configurations retrieved successfully")
|
||||
@console_ns.doc("get_default_block_configs")
|
||||
@console_ns.doc(description="Get default block configurations for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Default block configurations retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -689,12 +740,12 @@ parser_block = reqparse.RequestParser().add_argument("q", type=str, location="ar
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@api.doc("get_default_block_config")
|
||||
@api.doc(description="Get default block configuration by type")
|
||||
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@api.response(200, "Default block configuration retrieved successfully")
|
||||
@api.response(404, "Block type not found")
|
||||
@api.expect(parser_block)
|
||||
@console_ns.doc("get_default_block_config")
|
||||
@console_ns.doc(description="Get default block configuration by type")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@console_ns.response(200, "Default block configuration retrieved successfully")
|
||||
@console_ns.response(404, "Block type not found")
|
||||
@console_ns.expect(parser_block)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -731,13 +782,13 @@ parser_convert = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@api.expect(parser_convert)
|
||||
@api.doc("convert_to_workflow")
|
||||
@api.doc(description="Convert application to workflow mode")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Application converted to workflow successfully")
|
||||
@api.response(400, "Application cannot be converted")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.expect(parser_convert)
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Application converted to workflow successfully")
|
||||
@console_ns.response(400, "Application cannot be converted")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -777,16 +828,16 @@ parser_workflows = (
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.expect(parser_workflows)
|
||||
@api.doc("get_all_published_workflows")
|
||||
@api.doc(description="Get all published workflows for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields)
|
||||
@console_ns.expect(parser_workflows)
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_pagination_fields)
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
|
@ -826,11 +877,11 @@ class PublishedAllWorkflowApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@api.doc("update_workflow_by_id")
|
||||
@api.doc(description="Update workflow by ID")
|
||||
@api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateWorkflowRequest",
|
||||
{
|
||||
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
|
||||
|
|
@ -838,14 +889,14 @@ class WorkflowByIdApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow updated successfully", workflow_fields)
|
||||
@api.response(404, "Workflow not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
|
|
@ -926,17 +977,17 @@ class WorkflowByIdApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@api.doc("get_draft_workflow_node_last_run")
|
||||
@api.doc(description="Get last run result for draft workflow node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields)
|
||||
@api.response(404, "Node last run not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.doc("get_draft_workflow_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for draft workflow node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
@console_ns.response(404, "Node last run not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
workflow = srv.get_draft_workflow(app_model)
|
||||
|
|
@ -959,20 +1010,20 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_run")
|
||||
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("poll_draft_workflow_trigger_run")
|
||||
@console_ns.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowTriggerRunRequest",
|
||||
{
|
||||
"node_id": fields.String(required=True, description="Node ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@console_ns.response(200, "Trigger event received and workflow executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -1033,12 +1084,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
|||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_node")
|
||||
@api.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Trigger event received and node executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@console_ns.doc("poll_draft_workflow_trigger_node")
|
||||
@console_ns.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Trigger event received and node executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -1112,20 +1163,20 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
|
||||
"""
|
||||
|
||||
@api.doc("draft_workflow_trigger_run_all")
|
||||
@api.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("draft_workflow_trigger_run_all")
|
||||
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@console_ns.response(200, "Workflow executed successfully")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -3,24 +3,27 @@ from flask_restx import Resource, marshal_with, reqparse
|
|||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@api.doc("get_workflow_app_logs")
|
||||
@api.doc(description="Get workflow application execution logs")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(
|
||||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
|
|
@ -33,12 +36,12 @@ class WorkflowAppLogApi(Resource):
|
|||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
@marshal_with(workflow_app_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from flask import Response
|
|||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
|
|
@ -141,6 +141,37 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
|||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_draft_variable_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
|
||||
)
|
||||
|
||||
workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
|
||||
workflow_draft_env_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
|
||||
workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_variable_list_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
|
@ -170,14 +201,16 @@ def _api_prerequisite(f: Callable[P, R]):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@api.expect(_create_pagination_parser())
|
||||
@api.doc("get_workflow_variables")
|
||||
@api.doc(description="Get draft workflow variables")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.response(
|
||||
200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
|
||||
)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
|
|
@ -204,9 +237,9 @@ class WorkflowVariableCollectionApi(Resource):
|
|||
|
||||
return workflow_vars
|
||||
|
||||
@api.doc("delete_workflow_variables")
|
||||
@api.doc(description="Delete all draft workflow variables")
|
||||
@api.response(204, "Workflow variables deleted successfully")
|
||||
@console_ns.doc("delete_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
|
|
@ -237,12 +270,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@api.doc("get_node_variables")
|
||||
@api.doc(description="Get variables for a specific node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@console_ns.doc("get_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
|
|
@ -253,9 +286,9 @@ class NodeVariableCollectionApi(Resource):
|
|||
|
||||
return node_vars
|
||||
|
||||
@api.doc("delete_node_variables")
|
||||
@api.doc(description="Delete all variables for a specific node")
|
||||
@api.response(204, "Node variables deleted successfully")
|
||||
@console_ns.doc("delete_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
|
|
@ -270,13 +303,13 @@ class VariableApi(Resource):
|
|||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@api.doc("get_variable")
|
||||
@api.doc(description="Get a specific workflow variable")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("get_variable")
|
||||
@console_ns.doc(description="Get a specific workflow variable")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
|
|
@ -288,10 +321,10 @@ class VariableApi(Resource):
|
|||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@api.doc("update_variable")
|
||||
@api.doc(description="Update a workflow variable")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
|
|
@ -299,10 +332,10 @@ class VariableApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
|
|
@ -364,10 +397,10 @@ class VariableApi(Resource):
|
|||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@api.doc("delete_variable")
|
||||
@api.doc(description="Delete a workflow variable")
|
||||
@api.response(204, "Variable deleted successfully")
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("delete_variable")
|
||||
@console_ns.doc(description="Delete a workflow variable")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
|
|
@ -385,12 +418,12 @@ class VariableApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class VariableResetApi(Resource):
|
||||
@api.doc("reset_variable")
|
||||
@api.doc(description="Reset a workflow variable to its default value")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(204, "Variable reset (no content)")
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("reset_variable")
|
||||
@console_ns.doc(description="Reset a workflow variable to its default value")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
|
|
@ -414,7 +447,7 @@ class VariableResetApi(Resource):
|
|||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
|
|
@ -433,13 +466,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
|
|
@ -455,23 +488,23 @@ class ConversationVariableCollectionApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@console_ns.doc("get_system_variables")
|
||||
@console_ns.doc(description="Get system variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@api.doc("get_environment_variables")
|
||||
@api.doc(description="Get environment variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Environment variables retrieved successfully")
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@console_ns.doc("get_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
from typing import cast
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_for_list_fields,
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_for_list_fields,
|
||||
workflow_run_node_execution_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
|
|
@ -22,6 +27,71 @@ from services.workflow_run_service import WorkflowRunService
|
|||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
# Models that depend on simple_account_fields
|
||||
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
|
||||
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
|
||||
|
||||
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
|
||||
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
advanced_chat_workflow_run_for_list_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
|
||||
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
|
||||
|
||||
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
|
||||
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_model = console_ns.model(
|
||||
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
|
||||
)
|
||||
|
||||
# Simple models without nested dependencies
|
||||
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
|
||||
|
||||
# Pagination models that depend on list models
|
||||
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
|
||||
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
|
||||
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
|
||||
)
|
||||
advanced_chat_workflow_run_pagination_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
|
||||
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
|
||||
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
|
||||
|
||||
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
|
||||
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
|
||||
workflow_run_node_execution_list_model = console_ns.model(
|
||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
|
|
@ -90,18 +160,22 @@ def _parse_workflow_run_count_args():
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@api.doc("get_advanced_chat_workflow_runs")
|
||||
@api.doc(description="Get advanced chat workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs")
|
||||
@console_ns.doc(description="Get advanced chat workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_fields)
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
|
|
@ -125,11 +199,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@api.doc("get_advanced_chat_workflow_runs_count")
|
||||
@api.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
@console_ns.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
|
|
@ -137,13 +213,15 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
|
|
@ -170,18 +248,22 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs")
|
||||
class WorkflowRunListApi(Resource):
|
||||
@api.doc("get_workflow_runs")
|
||||
@api.doc(description="Get workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||
@console_ns.doc("get_workflow_runs")
|
||||
@console_ns.doc(description="Get workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_fields)
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
|
|
@ -205,11 +287,13 @@ class WorkflowRunListApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
class WorkflowRunCountApi(Resource):
|
||||
@api.doc("get_workflow_runs_count")
|
||||
@api.doc(description="Get workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
@console_ns.doc("get_workflow_runs_count")
|
||||
@console_ns.doc(description="Get workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
|
|
@ -217,13 +301,15 @@ class WorkflowRunCountApi(Resource):
|
|||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
|
|
@ -250,16 +336,16 @@ class WorkflowRunCountApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||
class WorkflowRunDetailApi(Resource):
|
||||
@api.doc("get_workflow_run_detail")
|
||||
@api.doc(description="Get workflow run detail")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@console_ns.doc("get_workflow_run_detail")
|
||||
@console_ns.doc(description="Get workflow run detail")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_fields)
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
|
|
@ -274,16 +360,16 @@ class WorkflowRunDetailApi(Resource):
|
|||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@api.doc("get_workflow_run_node_executions")
|
||||
@api.doc(description="Get workflow run node execution list")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@console_ns.doc("get_workflow_run_node_executions")
|
||||
@console_ns.doc(description="Get workflow run node execution list")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_fields)
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask import abort, jsonify
|
|||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -21,11 +21,13 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_runs_statistic")
|
||||
@api.doc(description="Get workflow daily runs statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily runs statistics retrieved successfully")
|
||||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -66,11 +68,13 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -111,11 +115,13 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||
@api.doc(description="Get workflow daily token cost statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -156,11 +162,13 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||
@api.doc(description="Get workflow average app interaction statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
|
||||
)
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -139,6 +139,6 @@ class AppTriggerEnableApi(Resource):
|
|||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
console_ns.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
console_ns.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
console_ns.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask import request
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
|
@ -20,13 +20,13 @@ active_check_parser = (
|
|||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@api.doc("check_activation_token")
|
||||
@api.doc(description="Check if activation token is valid")
|
||||
@api.expect(active_check_parser)
|
||||
@api.response(
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
|
|
@ -69,13 +69,13 @@ active_parser = (
|
|||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@api.doc("activate_account")
|
||||
@api.doc(description="Activate account with invitation token")
|
||||
@api.expect(active_parser)
|
||||
@api.response(
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
|
|
@ -83,7 +83,7 @@ class ActivateApi(Resource):
|
|||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Already activated or invalid token")
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import current_app, redirect, request
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
|
@ -29,19 +29,19 @@ def get_oauth_providers():
|
|||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>")
|
||||
class OAuthDataSource(Resource):
|
||||
@api.doc("oauth_data_source")
|
||||
@api.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@api.response(
|
||||
@console_ns.doc("oauth_data_source")
|
||||
@console_ns.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid provider")
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@is_admin_or_owner_required
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
|
|
@ -63,17 +63,17 @@ class OAuthDataSource(Resource):
|
|||
|
||||
@console_ns.route("/oauth/data-source/callback/<string:provider>")
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
@api.doc("oauth_data_source_callback")
|
||||
@api.doc(description="Handle OAuth callback from data source provider")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_data_source_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback from data source provider")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"provider": "Data source provider name (notion)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"error": "Error message from OAuth provider",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with result")
|
||||
@api.response(400, "Invalid provider")
|
||||
@console_ns.response(302, "Redirect to console with result")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -94,17 +94,17 @@ class OAuthDataSourceCallback(Resource):
|
|||
|
||||
@console_ns.route("/oauth/data-source/binding/<string:provider>")
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
@api.doc("oauth_data_source_binding")
|
||||
@api.doc(description="Bind OAuth data source with authorization code")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_data_source_binding")
|
||||
@console_ns.doc(description="Bind OAuth data source with authorization code")
|
||||
@console_ns.doc(
|
||||
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or code")
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
@ -128,15 +128,15 @@ class OAuthDataSourceBinding(Resource):
|
|||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@api.doc("oauth_data_source_sync")
|
||||
@api.doc(description="Sync data from OAuth data source")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("oauth_data_source_sync")
|
||||
@console_ns.doc(description="Sync data from OAuth data source")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or sync failed")
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from flask_restx import Resource, fields, reqparse
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
|
|
@ -27,10 +27,10 @@ from services.feature_service import FeatureService
|
|||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@api.doc("send_forgot_password_email")
|
||||
@api.doc(description="Send password reset email")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
|
|
@ -38,10 +38,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
|
|
@ -50,7 +50,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid email or rate limit exceeded")
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
@ -85,10 +85,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
|
||||
@console_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@api.doc("check_forgot_password_code")
|
||||
@api.doc(description="Verify password reset code")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
|
|
@ -97,10 +97,10 @@ class ForgotPasswordCheckApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
|
|
@ -109,7 +109,7 @@ class ForgotPasswordCheckApi(Resource):
|
|||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid code or token")
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
@ -152,10 +152,10 @@ class ForgotPasswordCheckApi(Resource):
|
|||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@api.doc("reset_password")
|
||||
@api.doc(description="Reset password with verification token")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
|
|
@ -164,12 +164,12 @@ class ForgotPasswordResetApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid token or password mismatch")
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
|
|||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import api, console_ns
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -56,11 +56,13 @@ def get_oauth_providers():
|
|||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@api.doc("oauth_login")
|
||||
@api.doc(description="Initiate OAuth login process")
|
||||
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
|
||||
@api.response(302, "Redirect to OAuth authorization URL")
|
||||
@api.response(400, "Invalid provider")
|
||||
@console_ns.doc("oauth_login")
|
||||
@console_ns.doc(description="Initiate OAuth login process")
|
||||
@console_ns.doc(
|
||||
params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
|
||||
)
|
||||
@console_ns.response(302, "Redirect to OAuth authorization URL")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
|
|
@ -75,17 +77,17 @@ class OAuthLogin(Resource):
|
|||
|
||||
@console_ns.route("/oauth/authorize/<provider>")
|
||||
class OAuthCallback(Resource):
|
||||
@api.doc("oauth_callback")
|
||||
@api.doc(description="Handle OAuth callback and complete login process")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback and complete login process")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"provider": "OAuth provider name (github/google)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"state": "Optional state parameter (used for invite token)",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with access token")
|
||||
@api.response(400, "OAuth process failed")
|
||||
@console_ns.response(302, "Redirect to console with access token")
|
||||
@console_ns.response(400, "OAuth process failed")
|
||||
def get(self, provider: str):
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import base64
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -48,17 +48,17 @@ class Invoices(Resource):
|
|||
|
||||
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
|
||||
class PartnerTenants(Resource):
|
||||
@api.doc("sync_partner_tenants_bindings")
|
||||
@api.doc(description="Sync partner tenants bindings")
|
||||
@api.doc(params={"partner_key": "Partner key"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("sync_partner_tenants_bindings")
|
||||
@console_ns.doc(description="Sync partner tenants bindings")
|
||||
@console_ns.doc(params={"partner_key": "Partner key"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncPartnerTenantsBindingsRequest",
|
||||
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tenants synced to partner successfully")
|
||||
@api.response(400, "Invalid partner information")
|
||||
@console_ns.response(200, "Tenants synced to partner successfully")
|
||||
@console_ns.response(400, "Invalid partner information")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -7,8 +7,11 @@ from werkzeug.exceptions import Forbidden, NotFound
|
|||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.apikey import api_key_fields, api_key_list
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import (
|
||||
api_key_item_model,
|
||||
api_key_list_model,
|
||||
)
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -27,8 +30,22 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
|
|||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_fields,
|
||||
dataset_query_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
|
|
@ -38,6 +55,58 @@ from models.provider_ids import ModelProviderID
|
|||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
tag_model = _get_or_create_model("Tag", tag_fields)
|
||||
|
||||
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
|
||||
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
|
||||
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
|
||||
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
|
||||
|
||||
app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
|
||||
related_app_list_copy = related_app_list.copy()
|
||||
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
|
||||
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
|
|
@ -119,9 +188,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
|||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
@api.doc(description="Get list of datasets")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_datasets")
|
||||
@console_ns.doc(description="Get list of datasets")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
|
|
@ -131,7 +200,7 @@ class DatasetListApi(Resource):
|
|||
"include_all": "Include all datasets (default: false)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Datasets retrieved successfully")
|
||||
@console_ns.response(200, "Datasets retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -184,10 +253,10 @@ class DatasetListApi(Resource):
|
|||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response, 200
|
||||
|
||||
@api.doc("create_dataset")
|
||||
@api.doc(description="Create a new dataset")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
|
||||
|
|
@ -200,8 +269,8 @@ class DatasetListApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Dataset created successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -279,12 +348,12 @@ class DatasetListApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>")
|
||||
class DatasetApi(Resource):
|
||||
@api.doc("get_dataset")
|
||||
@api.doc(description="Get dataset details")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Dataset retrieved successfully", dataset_detail_fields)
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.doc("get_dataset")
|
||||
@console_ns.doc(description="Get dataset details")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -328,10 +397,10 @@ class DatasetApi(Resource):
|
|||
|
||||
return data, 200
|
||||
|
||||
@api.doc("update_dataset")
|
||||
@api.doc(description="Update dataset details")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDatasetRequest",
|
||||
{
|
||||
"name": fields.String(description="Dataset name"),
|
||||
|
|
@ -342,9 +411,9 @@ class DatasetApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Dataset updated successfully", dataset_detail_fields)
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -488,10 +557,10 @@ class DatasetApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/use-check")
|
||||
class DatasetUseCheckApi(Resource):
|
||||
@api.doc("check_dataset_use")
|
||||
@api.doc(description="Check if dataset is in use")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Dataset use status retrieved successfully")
|
||||
@console_ns.doc("check_dataset_use")
|
||||
@console_ns.doc(description="Check if dataset is in use")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset use status retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -504,10 +573,10 @@ class DatasetUseCheckApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/queries")
|
||||
class DatasetQueryApi(Resource):
|
||||
@api.doc("get_dataset_queries")
|
||||
@api.doc(description="Get dataset query history")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Query history retrieved successfully", dataset_query_detail_fields)
|
||||
@console_ns.doc("get_dataset_queries")
|
||||
@console_ns.doc(description="Get dataset query history")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -529,7 +598,7 @@ class DatasetQueryApi(Resource):
|
|||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
|
||||
response = {
|
||||
"data": marshal(dataset_queries, dataset_query_detail_fields),
|
||||
"data": marshal(dataset_queries, dataset_query_detail_model),
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
|
|
@ -540,9 +609,9 @@ class DatasetQueryApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/indexing-estimate")
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
@api.doc("estimate_dataset_indexing")
|
||||
@api.doc(description="Estimate dataset indexing cost")
|
||||
@api.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.doc("estimate_dataset_indexing")
|
||||
@console_ns.doc(description="Estimate dataset indexing cost")
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -650,14 +719,14 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/related-apps")
|
||||
class DatasetRelatedAppListApi(Resource):
|
||||
@api.doc("get_dataset_related_apps")
|
||||
@api.doc(description="Get applications related to dataset")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Related apps retrieved successfully", related_app_list)
|
||||
@console_ns.doc("get_dataset_related_apps")
|
||||
@console_ns.doc(description="Get applications related to dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(related_app_list)
|
||||
@marshal_with(related_app_list_model)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
|
@ -683,10 +752,10 @@ class DatasetRelatedAppListApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
|
||||
class DatasetIndexingStatusApi(Resource):
|
||||
@api.doc("get_dataset_indexing_status")
|
||||
@api.doc(description="Get dataset indexing status")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.doc("get_dataset_indexing_status")
|
||||
@console_ns.doc(description="Get dataset indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -738,13 +807,13 @@ class DatasetApiKeyApi(Resource):
|
|||
token_prefix = "dataset-"
|
||||
resource_type = "dataset"
|
||||
|
||||
@api.doc("get_dataset_api_keys")
|
||||
@api.doc(description="Get dataset API keys")
|
||||
@api.response(200, "API keys retrieved successfully", api_key_list)
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get dataset API keys")
|
||||
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
keys = db.session.scalars(
|
||||
|
|
@ -756,7 +825,7 @@ class DatasetApiKeyApi(Resource):
|
|||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_fields)
|
||||
@marshal_with(api_key_item_model)
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
|
@ -767,7 +836,7 @@ class DatasetApiKeyApi(Resource):
|
|||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
api.abort(
|
||||
console_ns.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
|
|
@ -787,10 +856,10 @@ class DatasetApiKeyApi(Resource):
|
|||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = "dataset"
|
||||
|
||||
@api.doc("delete_dataset_api_key")
|
||||
@api.doc(description="Delete dataset API key")
|
||||
@api.doc(params={"api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete dataset API key")
|
||||
@console_ns.doc(params={"api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -809,7 +878,7 @@ class DatasetApiDeleteApi(Resource):
|
|||
)
|
||||
|
||||
if key is None:
|
||||
api.abort(404, message="API key not found")
|
||||
console_ns.abort(404, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
|
@ -832,9 +901,9 @@ class DatasetEnableApiApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/api-base-info")
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@api.doc("get_dataset_api_base_info")
|
||||
@api.doc(description="Get dataset API base information")
|
||||
@api.response(200, "API base info retrieved successfully")
|
||||
@console_ns.doc("get_dataset_api_base_info")
|
||||
@console_ns.doc(description="Get dataset API base information")
|
||||
@console_ns.response(200, "API base info retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -844,9 +913,9 @@ class DatasetApiBaseUrlApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/retrieval-setting")
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@api.doc("get_dataset_retrieval_setting")
|
||||
@api.doc(description="Get dataset retrieval settings")
|
||||
@api.response(200, "Retrieval settings retrieved successfully")
|
||||
@console_ns.doc("get_dataset_retrieval_setting")
|
||||
@console_ns.doc(description="Get dataset retrieval settings")
|
||||
@console_ns.response(200, "Retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -857,10 +926,10 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
class DatasetRetrievalSettingMockApi(Resource):
|
||||
@api.doc("get_dataset_retrieval_setting_mock")
|
||||
@api.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@api.doc(params={"vector_type": "Vector store type"})
|
||||
@api.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@console_ns.doc("get_dataset_retrieval_setting_mock")
|
||||
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@console_ns.doc(params={"vector_type": "Vector store type"})
|
||||
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -870,11 +939,11 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
class DatasetErrorDocs(Resource):
|
||||
@api.doc("get_dataset_error_docs")
|
||||
@api.doc(description="Get dataset error documents")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Error documents retrieved successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@console_ns.doc("get_dataset_error_docs")
|
||||
@console_ns.doc(description="Get dataset error documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Error documents retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -890,12 +959,12 @@ class DatasetErrorDocs(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
class DatasetPermissionUserListApi(Resource):
|
||||
@api.doc("get_dataset_permission_users")
|
||||
@api.doc(description="Get dataset permission user list")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Permission users retrieved successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.doc("get_dataset_permission_users")
|
||||
@console_ns.doc(description="Get dataset permission user list")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Permission users retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -919,11 +988,11 @@ class DatasetPermissionUserListApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
|
||||
class DatasetAutoDisableLogApi(Resource):
|
||||
@api.doc("get_dataset_auto_disable_logs")
|
||||
@api.doc(description="Get dataset auto disable logs")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.response(200, "Auto disable logs retrieved successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@console_ns.doc("get_dataset_auto_disable_logs")
|
||||
@console_ns.doc(description="Get dataset auto disable logs")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Auto disable logs retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from sqlalchemy import asc, desc, select
|
|||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
|
|
@ -45,9 +45,11 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
|
|||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_metadata_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
|
|
@ -61,6 +63,36 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = _get_or_create_model("Dataset", dataset_fields)
|
||||
|
||||
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
|
||||
|
||||
document_fields_copy = document_fields.copy()
|
||||
document_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_model = _get_or_create_model("Document", document_fields_copy)
|
||||
|
||||
document_with_segments_fields_copy = document_with_segments_fields.copy()
|
||||
document_with_segments_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
|
||||
|
||||
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
|
||||
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
|
||||
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
|
||||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
|
@ -104,10 +136,10 @@ class DocumentResource(Resource):
|
|||
|
||||
@console_ns.route("/datasets/process-rule")
|
||||
class GetProcessRuleApi(Resource):
|
||||
@api.doc("get_process_rule")
|
||||
@api.doc(description="Get dataset document processing rules")
|
||||
@api.doc(params={"document_id": "Document ID (optional)"})
|
||||
@api.response(200, "Process rules retrieved successfully")
|
||||
@console_ns.doc("get_process_rule")
|
||||
@console_ns.doc(description="Get dataset document processing rules")
|
||||
@console_ns.doc(params={"document_id": "Document ID (optional)"})
|
||||
@console_ns.response(200, "Process rules retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -152,9 +184,9 @@ class GetProcessRuleApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
class DatasetDocumentListApi(Resource):
|
||||
@api.doc("get_dataset_documents")
|
||||
@api.doc(description="Get documents in a dataset")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_dataset_documents")
|
||||
@console_ns.doc(description="Get documents in a dataset")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"page": "Page number (default: 1)",
|
||||
|
|
@ -165,13 +197,12 @@ class DatasetDocumentListApi(Resource):
|
|||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Documents retrieved successfully")
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
def get(self, dataset_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
|
@ -276,7 +307,7 @@ class DatasetDocumentListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
|
|
@ -357,10 +388,10 @@ class DatasetDocumentListApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/init")
|
||||
class DatasetInitApi(Resource):
|
||||
@api.doc("init_dataset")
|
||||
@api.doc(description="Initialize dataset with documents")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DatasetInitRequest",
|
||||
{
|
||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
||||
|
|
@ -370,12 +401,12 @@ class DatasetInitApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Dataset initialized successfully", dataset_and_document_fields)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
|
|
@ -446,12 +477,12 @@ class DatasetInitApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@api.doc("estimate_document_indexing")
|
||||
@api.doc(description="Estimate document indexing cost")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing estimate calculated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Document already finished")
|
||||
@console_ns.doc("estimate_document_indexing")
|
||||
@console_ns.doc(description="Estimate document indexing cost")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Document already finished")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -661,11 +692,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
@api.doc("get_document_indexing_status")
|
||||
@api.doc(description="Get document indexing status")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing status retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -711,17 +742,17 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||
class DocumentApi(DocumentResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@api.doc("get_document")
|
||||
@api.doc(description="Get document details")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_document")
|
||||
@console_ns.doc(description="Get document details")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"metadata": "Metadata inclusion (all/only/without)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Document retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@console_ns.response(200, "Document retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -832,14 +863,14 @@ class DocumentApi(DocumentResource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@api.doc("update_document_processing")
|
||||
@api.doc(description="Update document processing status (pause/resume)")
|
||||
@api.doc(
|
||||
@console_ns.doc("update_document_processing")
|
||||
@console_ns.doc(description="Update document processing status (pause/resume)")
|
||||
@console_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
|
||||
)
|
||||
@api.response(200, "Processing status updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Invalid action")
|
||||
@console_ns.response(200, "Processing status updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Invalid action")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -877,11 +908,11 @@ class DocumentProcessingApi(DocumentResource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
class DocumentMetadataApi(DocumentResource):
|
||||
@api.doc("update_document_metadata")
|
||||
@api.doc(description="Update document metadata")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_document_metadata")
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDocumentMetadataRequest",
|
||||
{
|
||||
"doc_type": fields.String(description="Document type"),
|
||||
|
|
@ -889,9 +920,9 @@ class DocumentMetadataApi(DocumentResource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Document metadata updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -3,10 +3,22 @@ from flask_restx import Resource, fields, marshal, reqparse
|
|||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
|
@ -14,6 +26,51 @@ from services.hit_testing_service import HitTestingService
|
|||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
tag_model = _get_or_create_model("Tag", tag_fields)
|
||||
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
|
||||
try:
|
||||
dataset_detail_model = console_ns.models["DatasetDetail"]
|
||||
except KeyError:
|
||||
dataset_detail_model = _build_dataset_detail_model()
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
|
|
@ -22,16 +79,16 @@ def _validate_name(name: str) -> str:
|
|||
|
||||
@console_ns.route("/datasets/external-knowledge-api")
|
||||
class ExternalApiTemplateListApi(Resource):
|
||||
@api.doc("get_external_api_templates")
|
||||
@api.doc(description="Get external knowledge API templates")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_external_api_templates")
|
||||
@console_ns.doc(description="Get external knowledge API templates")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"keyword": "Search keyword",
|
||||
}
|
||||
)
|
||||
@api.response(200, "External API templates retrieved successfully")
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -95,11 +152,11 @@ class ExternalApiTemplateListApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
|
||||
class ExternalApiTemplateApi(Resource):
|
||||
@api.doc("get_external_api_template")
|
||||
@api.doc(description="Get external knowledge API template details")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "External API template retrieved successfully")
|
||||
@api.response(404, "Template not found")
|
||||
@console_ns.doc("get_external_api_template")
|
||||
@console_ns.doc(description="Get external knowledge API template details")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "External API template retrieved successfully")
|
||||
@console_ns.response(404, "Template not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -163,10 +220,10 @@ class ExternalApiTemplateApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
class ExternalApiUseCheckApi(Resource):
|
||||
@api.doc("check_external_api_usage")
|
||||
@api.doc(description="Check if external knowledge API is being used")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "Usage check completed successfully")
|
||||
@console_ns.doc("check_external_api_usage")
|
||||
@console_ns.doc(description="Check if external knowledge API is being used")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "Usage check completed successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -181,10 +238,10 @@ class ExternalApiUseCheckApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/external")
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@api.doc("create_external_dataset")
|
||||
@api.doc(description="Create external knowledge dataset")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_external_dataset")
|
||||
@console_ns.doc(description="Create external knowledge dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateExternalDatasetRequest",
|
||||
{
|
||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
||||
|
|
@ -194,9 +251,9 @@ class ExternalDatasetCreateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "External dataset created successfully", dataset_detail_fields)
|
||||
@api.response(400, "Invalid parameters")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -239,11 +296,11 @@ class ExternalDatasetCreateApi(Resource):
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
|
||||
class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@api.doc("test_external_knowledge_retrieval")
|
||||
@api.doc(description="Test external knowledge retrieval for dataset")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("test_external_knowledge_retrieval")
|
||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ExternalHitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
|
|
@ -252,9 +309,9 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "External hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@console_ns.response(200, "External hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -297,10 +354,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||
@console_ns.route("/test/retrieval")
|
||||
class BedrockRetrievalApi(Resource):
|
||||
# this api is only for internal testing
|
||||
@api.doc("bedrock_retrieval_test")
|
||||
@api.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("bedrock_retrieval_test")
|
||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"BedrockRetrievalTestRequest",
|
||||
{
|
||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
||||
|
|
@ -309,7 +366,7 @@ class BedrockRetrievalApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Bedrock retrieval test completed")
|
||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
|
|
@ -12,11 +12,11 @@ from libs.login import login_required
|
|||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@api.doc("test_dataset_retrieval")
|
||||
@api.doc(description="Test dataset knowledge retrieval")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"HitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
|
|
@ -26,9 +26,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -130,7 +130,7 @@ parser_datasource = (
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@api.expect(parser_datasource)
|
||||
@console_ns.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -176,7 +176,7 @@ parser_datasource_delete = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@api.expect(parser_datasource_delete)
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -209,7 +209,7 @@ parser_datasource_update = (
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@api.expect(parser_datasource_update)
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -267,7 +267,7 @@ parser_datasource_custom = (
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@api.expect(parser_datasource_custom)
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -306,7 +306,7 @@ parser_default = reqparse.RequestParser().add_argument("id", type=str, required=
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -334,7 +334,7 @@ parser_update_name = (
|
|||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@api.expect(parser_update_name)
|
||||
@console_ns.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
|||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@api.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -38,7 +38,7 @@ class DataSourceContentPreviewApi(Resource):
|
|||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = Parser.model_validate(api.payload)
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
datasource_type = args.datasource_type
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
DraftWorkflowNotExist,
|
||||
|
|
@ -153,7 +153,7 @@ parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@console_ns.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -187,7 +187,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@api.expect(parser_run)
|
||||
@console_ns.expect(parser_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -230,7 +230,7 @@ parser_draft_run = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_draft_run)
|
||||
@console_ns.expect(parser_draft_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -273,7 +273,7 @@ parser_published_run = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@api.expect(parser_published_run)
|
||||
@console_ns.expect(parser_published_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -397,7 +397,7 @@ parser_rag_run = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -437,7 +437,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@api.expect(parser_rag_run)
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -482,7 +482,7 @@ parser_run_api = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@api.expect(parser_run_api)
|
||||
@console_ns.expect(parser_run_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
|
|
@ -607,7 +607,7 @@ parser_default = reqparse.RequestParser().add_argument("q", type=str, location="
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -644,7 +644,7 @@ parser_wf = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@api.expect(parser_wf)
|
||||
@console_ns.expect(parser_wf)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -696,7 +696,7 @@ parser_wf_id = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@api.expect(parser_wf_id)
|
||||
@console_ns.expect(parser_wf_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -754,7 +754,7 @@ parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, r
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -777,7 +777,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -800,7 +800,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -823,7 +823,7 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@api.expect(parser_parameters)
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -854,7 +854,7 @@ parser_wf_run = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@api.expect(parser_wf_run)
|
||||
@console_ns.expect(parser_wf_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -975,7 +975,7 @@ parser_var = (
|
|||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@api.expect(parser_var)
|
||||
@console_ns.expect(parser_var)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
|
|
@ -9,10 +9,10 @@ from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusA
|
|||
|
||||
@console_ns.route("/website/crawl")
|
||||
class WebsiteCrawlApi(Resource):
|
||||
@api.doc("crawl_website")
|
||||
@api.doc(description="Crawl website content")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("crawl_website")
|
||||
@console_ns.doc(description="Crawl website content")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"WebsiteCrawlRequest",
|
||||
{
|
||||
"provider": fields.String(
|
||||
|
|
@ -25,8 +25,8 @@ class WebsiteCrawlApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Website crawl initiated successfully")
|
||||
@api.response(400, "Invalid crawl parameters")
|
||||
@console_ns.response(200, "Website crawl initiated successfully")
|
||||
@console_ns.response(400, "Invalid crawl parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -62,12 +62,12 @@ class WebsiteCrawlApi(Resource):
|
|||
|
||||
@console_ns.route("/website/crawl/status/<string:job_id>")
|
||||
class WebsiteCrawlStatusApi(Resource):
|
||||
@api.doc("get_crawl_status")
|
||||
@api.doc(description="Get website crawl status")
|
||||
@api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
|
||||
@api.response(200, "Crawl status retrieved successfully")
|
||||
@api.response(404, "Crawl job not found")
|
||||
@api.response(400, "Invalid provider")
|
||||
@console_ns.doc("get_crawl_status")
|
||||
@console_ns.doc(description="Get website crawl status")
|
||||
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
|
||||
@console_ns.response(200, "Crawl status retrieved successfully")
|
||||
@console_ns.response(404, "Crawl job not found")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -1,44 +1,40 @@
|
|||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Pipeline
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
def get_rag_pipeline(
|
||||
view: Callable | None = None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
def get_rag_pipeline(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return decorated_view
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
return decorated_view
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import AppIconUrlField
|
||||
from libs.login import current_user, login_required
|
||||
|
|
@ -40,7 +40,7 @@ parser_apps = reqparse.RequestParser().add_argument("language", type=str, locati
|
|||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@api.expect(parser_apps)
|
||||
@console_ns.expect(parser_apps)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(recommended_app_list_fields)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -9,18 +9,24 @@ from models.api_based_extension import APIBasedExtension
|
|||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||
|
||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||
|
||||
|
||||
@console_ns.route("/code-based-extension")
|
||||
class CodeBasedExtensionAPI(Resource):
|
||||
@api.doc("get_code_based_extension")
|
||||
@api.doc(description="Get code-based extension data by module name")
|
||||
@api.expect(
|
||||
api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
|
||||
@console_ns.doc("get_code_based_extension")
|
||||
@console_ns.doc(description="Get code-based extension data by module name")
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"module", type=str, required=True, location="args", help="Extension module name"
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"CodeBasedExtensionResponse",
|
||||
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
|
||||
),
|
||||
|
|
@ -37,21 +43,21 @@ class CodeBasedExtensionAPI(Resource):
|
|||
|
||||
@console_ns.route("/api-based-extension")
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
@api.doc("get_api_based_extensions")
|
||||
@api.doc(description="Get all API-based extensions for current tenant")
|
||||
@api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
|
||||
@console_ns.doc("get_api_based_extensions")
|
||||
@console_ns.doc(description="Get all API-based extensions for current tenant")
|
||||
@console_ns.response(200, "Success", api_based_extension_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
@api.doc("create_api_based_extension")
|
||||
@api.doc(description="Create a new API-based extension")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@console_ns.doc(description="Create a new API-based extension")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
|
|
@ -60,13 +66,13 @@ class APIBasedExtensionAPI(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Extension created successfully", api_based_extension_fields)
|
||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self):
|
||||
args = api.payload
|
||||
args = console_ns.payload
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
|
|
@ -81,25 +87,25 @@ class APIBasedExtensionAPI(Resource):
|
|||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
class APIBasedExtensionDetailAPI(Resource):
|
||||
@api.doc("get_api_based_extension")
|
||||
@api.doc(description="Get API-based extension by ID")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.response(200, "Success", api_based_extension_fields)
|
||||
@console_ns.doc("get_api_based_extension")
|
||||
@console_ns.doc(description="Get API-based extension by ID")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.response(200, "Success", api_based_extension_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
@marshal_with(api_based_extension_model)
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
@api.doc("update_api_based_extension")
|
||||
@api.doc(description="Update API-based extension")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@console_ns.doc(description="Update API-based extension")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateAPIBasedExtensionRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="Extension name"),
|
||||
|
|
@ -108,18 +114,18 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Extension updated successfully", api_based_extension_fields)
|
||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
@marshal_with(api_based_extension_model)
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
args = api.payload
|
||||
args = console_ns.payload
|
||||
|
||||
extension_data_from_db.name = args["name"]
|
||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||
|
|
@ -129,10 +135,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
|
||||
@api.doc("delete_api_based_extension")
|
||||
@api.doc(description="Delete API-based extension")
|
||||
@api.doc(params={"id": "Extension ID"})
|
||||
@api.response(204, "Extension deleted successfully")
|
||||
@console_ns.doc("delete_api_based_extension")
|
||||
@console_ns.doc(description="Delete API-based extension")
|
||||
@console_ns.doc(params={"id": "Extension ID"})
|
||||
@console_ns.response(204, "Extension deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -3,18 +3,18 @@ from flask_restx import Resource, fields
|
|||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@api.doc("get_tenant_features")
|
||||
@api.doc(description="Get feature configuration for current tenant")
|
||||
@api.response(
|
||||
@console_ns.doc("get_tenant_features")
|
||||
@console_ns.doc(description="Get feature configuration for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -29,12 +29,14 @@ class FeatureApi(Resource):
|
|||
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@api.doc("get_system_features")
|
||||
@api.doc(description="Get system-wide feature configuration")
|
||||
@api.response(
|
||||
@console_ns.doc("get_system_features")
|
||||
@console_ns.doc(description="Get system-wide feature configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
|
||||
console_ns.model(
|
||||
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration"""
|
||||
|
|
|
|||
|
|
@ -11,19 +11,19 @@ from libs.helper import StrLen
|
|||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
from .error import AlreadySetupError, InitValidateFailedError
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
||||
|
||||
@console_ns.route("/init")
|
||||
class InitValidateAPI(Resource):
|
||||
@api.doc("get_init_status")
|
||||
@api.doc(description="Get initialization validation status")
|
||||
@api.response(
|
||||
@console_ns.doc("get_init_status")
|
||||
@console_ns.doc(description="Get initialization validation status")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
model=api.model(
|
||||
model=console_ns.model(
|
||||
"InitStatusResponse",
|
||||
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
|
||||
),
|
||||
|
|
@ -35,20 +35,20 @@ class InitValidateAPI(Resource):
|
|||
return {"status": "finished"}
|
||||
return {"status": "not_started"}
|
||||
|
||||
@api.doc("validate_init_password")
|
||||
@api.doc(description="Validate initialization password for self-hosted edition")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("validate_init_password")
|
||||
@console_ns.doc(description="Validate initialization password for self-hosted edition")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InitValidateRequest",
|
||||
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
201,
|
||||
"Success",
|
||||
model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
||||
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Already setup or validation failed")
|
||||
@console_ns.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Validate initialization password"""
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
|
||||
|
||||
@console_ns.route("/ping")
|
||||
class PingApi(Resource):
|
||||
@api.doc("health_check")
|
||||
@api.doc(description="Health check endpoint for connection testing")
|
||||
@api.response(
|
||||
@console_ns.doc("health_check")
|
||||
@console_ns.doc(description="Health check endpoint for connection testing")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
)
|
||||
def get(self):
|
||||
"""Health check endpoint for connection testing"""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from controllers.common.errors import (
|
|||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import api
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -42,7 +41,7 @@ parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=
|
|||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUploadApi(Resource):
|
||||
@api.expect(parser_upload)
|
||||
@console_ns.expect(parser_upload)
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
args = parser_upload.parse_args()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from libs.password import valid_password
|
|||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
from .wraps import only_edition_self_hosted
|
||||
|
|
@ -15,12 +15,12 @@ from .wraps import only_edition_self_hosted
|
|||
|
||||
@console_ns.route("/setup")
|
||||
class SetupApi(Resource):
|
||||
@api.doc("get_setup_status")
|
||||
@api.doc(description="Get system setup status")
|
||||
@api.response(
|
||||
@console_ns.doc("get_setup_status")
|
||||
@console_ns.doc(description="Get system setup status")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"SetupStatusResponse",
|
||||
{
|
||||
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
|
||||
|
|
@ -40,10 +40,10 @@ class SetupApi(Resource):
|
|||
return {"step": "not_started"}
|
||||
return {"step": "finished"}
|
||||
|
||||
@api.doc("setup_system")
|
||||
@api.doc(description="Initialize system setup with admin account")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("setup_system")
|
||||
@console_ns.doc(description="Initialize system setup with admin account")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SetupRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Admin email address"),
|
||||
|
|
@ -53,8 +53,10 @@ class SetupApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
|
||||
@api.response(400, "Already setup or validation failed")
|
||||
@console_ns.response(
|
||||
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
|
||||
)
|
||||
@console_ns.response(400, "Already setup or validation failed")
|
||||
@only_edition_self_hosted
|
||||
def post(self):
|
||||
"""Initialize system setup with admin account"""
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from flask import request
|
|||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -43,7 +43,7 @@ class TagListApi(Resource):
|
|||
|
||||
return tags, 200
|
||||
|
||||
@api.expect(parser_tags)
|
||||
@console_ns.expect(parser_tags)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -68,7 +68,7 @@ parser_tag_id = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@api.expect(parser_tag_id)
|
||||
@console_ns.expect(parser_tag_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -110,7 +110,7 @@ parser_create = (
|
|||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@console_ns.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -136,7 +136,7 @@ parser_remove = (
|
|||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@api.expect(parser_remove)
|
||||
@console_ns.expect(parser_remove)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from packaging import version
|
|||
|
||||
from configs import dify_config
|
||||
|
||||
from . import api, console_ns
|
||||
from . import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -18,13 +18,13 @@ parser = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/version")
|
||||
class VersionApi(Resource):
|
||||
@api.doc("check_version_update")
|
||||
@api.doc(description="Check for application version updates")
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
@console_ns.doc("check_version_update")
|
||||
@console_ns.doc(description="Check for application version updates")
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"VersionResponse",
|
||||
{
|
||||
"version": fields.String(description="Latest version number"),
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailChangeLimitError,
|
||||
|
|
@ -55,7 +55,7 @@ def _init_parser():
|
|||
|
||||
@console_ns.route("/account/init")
|
||||
class AccountInitApi(Resource):
|
||||
@api.expect(_init_parser())
|
||||
@console_ns.expect(_init_parser())
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
|
|
@ -115,7 +115,7 @@ parser_name = reqparse.RequestParser().add_argument("name", type=str, required=T
|
|||
|
||||
@console_ns.route("/account/name")
|
||||
class AccountNameApi(Resource):
|
||||
@api.expect(parser_name)
|
||||
@console_ns.expect(parser_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -138,7 +138,7 @@ parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, requir
|
|||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@api.expect(parser_avatar)
|
||||
@console_ns.expect(parser_avatar)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -159,7 +159,7 @@ parser_interface = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/account/interface-language")
|
||||
class AccountInterfaceLanguageApi(Resource):
|
||||
@api.expect(parser_interface)
|
||||
@console_ns.expect(parser_interface)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -180,7 +180,7 @@ parser_theme = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
class AccountInterfaceThemeApi(Resource):
|
||||
@api.expect(parser_theme)
|
||||
@console_ns.expect(parser_theme)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -199,7 +199,7 @@ parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, re
|
|||
|
||||
@console_ns.route("/account/timezone")
|
||||
class AccountTimezoneApi(Resource):
|
||||
@api.expect(parser_timezone)
|
||||
@console_ns.expect(parser_timezone)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -227,7 +227,7 @@ parser_pw = (
|
|||
|
||||
@console_ns.route("/account/password")
|
||||
class AccountPasswordApi(Resource):
|
||||
@api.expect(parser_pw)
|
||||
@console_ns.expect(parser_pw)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -325,7 +325,7 @@ parser_delete = (
|
|||
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@console_ns.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -351,7 +351,7 @@ parser_feedback = (
|
|||
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@api.expect(parser_feedback)
|
||||
@console_ns.expect(parser_feedback)
|
||||
@setup_required
|
||||
def post(self):
|
||||
args = parser_feedback.parse_args()
|
||||
|
|
@ -396,7 +396,7 @@ class EducationApi(Resource):
|
|||
"allow_refresh": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_edu)
|
||||
@console_ns.expect(parser_edu)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -441,7 +441,7 @@ class EducationAutoCompleteApi(Resource):
|
|||
"has_next": fields.Boolean,
|
||||
}
|
||||
|
||||
@api.expect(parser_autocomplete)
|
||||
@console_ns.expect(parser_autocomplete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -465,7 +465,7 @@ parser_change_email = (
|
|||
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@api.expect(parser_change_email)
|
||||
@console_ns.expect(parser_change_email)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -517,7 +517,7 @@ parser_validity = (
|
|||
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@api.expect(parser_validity)
|
||||
@console_ns.expect(parser_validity)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -563,7 +563,7 @@ parser_reset = (
|
|||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
class ChangeEmailResetApi(Resource):
|
||||
@api.expect(parser_reset)
|
||||
@console_ns.expect(parser_reset)
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -603,7 +603,7 @@ parser_check = reqparse.RequestParser().add_argument("email", type=email, requir
|
|||
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@api.expect(parser_check)
|
||||
@console_ns.expect(parser_check)
|
||||
@setup_required
|
||||
def post(self):
|
||||
args = parser_check.parse_args()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
|
|
@ -9,9 +9,9 @@ from services.agent_service import AgentService
|
|||
|
||||
@console_ns.route("/workspaces/current/agent-providers")
|
||||
class AgentProviderListApi(Resource):
|
||||
@api.doc("list_agent_providers")
|
||||
@api.doc(description="Get list of available agent providers")
|
||||
@api.response(
|
||||
@console_ns.doc("list_agent_providers")
|
||||
@console_ns.doc(description="Get list of available agent providers")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
fields.List(fields.Raw(description="Agent provider information")),
|
||||
|
|
@ -31,10 +31,10 @@ class AgentProviderListApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
class AgentProviderApi(Resource):
|
||||
@api.doc("get_agent_provider")
|
||||
@api.doc(description="Get specific agent provider details")
|
||||
@api.doc(params={"provider_name": "Agent provider name"})
|
||||
@api.response(
|
||||
@console_ns.doc("get_agent_provider")
|
||||
@console_ns.doc(description="Get specific agent provider details")
|
||||
@console_ns.doc(params={"provider_name": "Agent provider name"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
fields.Raw(description="Agent provider details"),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
|
@ -10,10 +10,10 @@ from services.plugin.endpoint_service import EndpointService
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
@api.doc(description="Create a new plugin endpoint")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointCreateRequest",
|
||||
{
|
||||
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
|
||||
|
|
@ -22,12 +22,12 @@ class EndpointCreateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -63,17 +63,19 @@ class EndpointCreateApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
class EndpointListApi(Resource):
|
||||
@api.doc("list_endpoints")
|
||||
@api.doc(description="List plugin endpoints with pagination")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_endpoints")
|
||||
@console_ns.doc(description="List plugin endpoints with pagination")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
|
||||
console_ns.model(
|
||||
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -105,18 +107,18 @@ class EndpointListApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/list/plugin")
|
||||
class EndpointListForSinglePluginApi(Resource):
|
||||
@api.doc("list_plugin_endpoints")
|
||||
@api.doc(description="List endpoints for a specific plugin")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_plugin_endpoints")
|
||||
@console_ns.doc(description="List endpoints for a specific plugin")
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, required=True, location="args", help="Page number")
|
||||
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
|
||||
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
|
||||
),
|
||||
)
|
||||
|
|
@ -153,17 +155,19 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@api.doc("delete_endpoint")
|
||||
@api.doc(description="Delete a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -183,10 +187,10 @@ class EndpointDeleteApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class EndpointUpdateApi(Resource):
|
||||
@api.doc("update_endpoint")
|
||||
@api.doc(description="Update a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointUpdateRequest",
|
||||
{
|
||||
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
|
||||
|
|
@ -195,12 +199,12 @@ class EndpointUpdateApi(Resource):
|
|||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -233,17 +237,19 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
class EndpointEnableApi(Resource):
|
||||
@api.doc("enable_endpoint")
|
||||
@api.doc(description="Enable a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
@console_ns.doc("enable_endpoint")
|
||||
@console_ns.doc(description="Enable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint enabled successfully",
|
||||
api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -263,17 +269,19 @@ class EndpointEnableApi(Resource):
|
|||
|
||||
@console_ns.route("/workspaces/current/endpoints/disable")
|
||||
class EndpointDisableApi(Resource):
|
||||
@api.doc("disable_endpoint")
|
||||
@api.doc(description="Disable a plugin endpoint")
|
||||
@api.expect(
|
||||
api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
|
||||
@console_ns.doc("disable_endpoint")
|
||||
@console_ns.doc(description="Disable a plugin endpoint")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint disabled successfully",
|
||||
api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
|
||||
)
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
|||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
|
|
@ -60,7 +60,7 @@ parser_invite = (
|
|||
class MemberInviteEmailApi(Resource):
|
||||
"""Invite a new member by email."""
|
||||
|
||||
@api.expect(parser_invite)
|
||||
@console_ns.expect(parser_invite)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -153,7 +153,7 @@ parser_update = reqparse.RequestParser().add_argument("role", type=str, required
|
|||
class MemberUpdateRoleApi(Resource):
|
||||
"""Update member role."""
|
||||
|
||||
@api.expect(parser_update)
|
||||
@console_ns.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -204,7 +204,7 @@ parser_send = reqparse.RequestParser().add_argument("language", type=str, requir
|
|||
class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
@api.expect(parser_send)
|
||||
@console_ns.expect(parser_send)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -247,7 +247,7 @@ parser_owner = (
|
|||
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@api.expect(parser_owner)
|
||||
@console_ns.expect(parser_owner)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -295,7 +295,7 @@ parser_owner_transfer = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
|
||||
class OwnerTransfer(Resource):
|
||||
@api.expect(parser_owner_transfer)
|
||||
@console_ns.expect(parser_owner_transfer)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
from flask import send_file
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
|
@ -25,7 +25,7 @@ parser_model = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
class ModelProviderListApi(Resource):
|
||||
@api.expect(parser_model)
|
||||
@console_ns.expect(parser_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -64,7 +64,7 @@ parser_delete_cred = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
class ModelProviderCredentialApi(Resource):
|
||||
@api.expect(parser_cred)
|
||||
@console_ns.expect(parser_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -81,7 +81,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -104,7 +104,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -129,7 +129,7 @@ class ModelProviderCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -153,7 +153,7 @@ parser_switch = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@console_ns.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -178,7 +178,7 @@ parser_validate = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@console_ns.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -240,7 +240,7 @@ parser_preferred = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@api.expect(parser_preferred)
|
||||
@console_ns.expect(parser_preferred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
|
@ -30,7 +30,7 @@ parser_post_default = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/default-model")
|
||||
class DefaultModelApi(Resource):
|
||||
@api.expect(parser_get_default)
|
||||
@console_ns.expect(parser_get_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -46,7 +46,7 @@ class DefaultModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@api.expect(parser_post_default)
|
||||
@console_ns.expect(parser_post_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -127,7 +127,7 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return jsonable_encoder({"data": models})
|
||||
|
||||
@api.expect(parser_post_models)
|
||||
@console_ns.expect(parser_post_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -173,7 +173,7 @@ class ModelProviderModelApi(Resource):
|
|||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@api.expect(parser_delete_models)
|
||||
@console_ns.expect(parser_delete_models)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -253,7 +253,7 @@ parser_delete_cred = (
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
class ModelProviderModelCredentialApi(Resource):
|
||||
@api.expect(parser_get_credentials)
|
||||
@console_ns.expect(parser_get_credentials)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -304,7 +304,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
}
|
||||
)
|
||||
|
||||
@api.expect(parser_post_cred)
|
||||
@console_ns.expect(parser_post_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -336,7 +336,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@api.expect(parser_put_cred)
|
||||
@console_ns.expect(parser_put_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -362,7 +362,7 @@ class ModelProviderModelCredentialApi(Resource):
|
|||
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_delete_cred)
|
||||
@console_ns.expect(parser_delete_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -400,7 +400,7 @@ parser_switch = (
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@console_ns.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -439,7 +439,7 @@ parser_model_enable_disable = (
|
|||
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -460,7 +460,7 @@ class ModelProviderModelEnableApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@api.expect(parser_model_enable_disable)
|
||||
@console_ns.expect(parser_model_enable_disable)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -494,7 +494,7 @@ parser_validate = (
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
|
||||
class ModelProviderModelValidateApi(Resource):
|
||||
@api.expect(parser_validate)
|
||||
@console_ns.expect(parser_validate)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -535,7 +535,7 @@ parser_parameter = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
|
||||
class ModelProviderModelParameterRuleApi(Resource):
|
||||
@api.expect(parser_parameter)
|
||||
@console_ns.expect(parser_parameter)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -46,7 +46,7 @@ parser_list = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/list")
|
||||
class PluginListApi(Resource):
|
||||
@api.expect(parser_list)
|
||||
@console_ns.expect(parser_list)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -66,7 +66,7 @@ parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, r
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/list/latest-versions")
|
||||
class PluginListLatestVersionsApi(Resource):
|
||||
@api.expect(parser_latest)
|
||||
@console_ns.expect(parser_latest)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -86,7 +86,7 @@ parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, requ
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/list/installations/ids")
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@api.expect(parser_ids)
|
||||
@console_ns.expect(parser_ids)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -112,7 +112,7 @@ parser_icon = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/icon")
|
||||
class PluginIconApi(Resource):
|
||||
@api.expect(parser_icon)
|
||||
@console_ns.expect(parser_icon)
|
||||
@setup_required
|
||||
def get(self):
|
||||
args = parser_icon.parse_args()
|
||||
|
|
@ -181,7 +181,7 @@ parser_github = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/upload/github")
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@api.expect(parser_github)
|
||||
@console_ns.expect(parser_github)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -230,7 +230,7 @@ parser_pkg = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/install/pkg")
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@api.expect(parser_pkg)
|
||||
@console_ns.expect(parser_pkg)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -263,7 +263,7 @@ parser_githubapi = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/install/github")
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@api.expect(parser_githubapi)
|
||||
@console_ns.expect(parser_githubapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -294,7 +294,7 @@ parser_marketplace = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/install/marketplace")
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace)
|
||||
@console_ns.expect(parser_marketplace)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -324,7 +324,7 @@ parser_pkgapi = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/marketplace/pkg")
|
||||
class PluginFetchMarketplacePkgApi(Resource):
|
||||
@api.expect(parser_pkgapi)
|
||||
@console_ns.expect(parser_pkgapi)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -353,7 +353,7 @@ parser_fetch = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/fetch-manifest")
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@api.expect(parser_fetch)
|
||||
@console_ns.expect(parser_fetch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -384,7 +384,7 @@ parser_tasks = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks")
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@api.expect(parser_tasks)
|
||||
@console_ns.expect(parser_tasks)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -471,7 +471,7 @@ parser_marketplace_api = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
|
||||
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@api.expect(parser_marketplace_api)
|
||||
@console_ns.expect(parser_marketplace_api)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -503,7 +503,7 @@ parser_github_post = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/upgrade/github")
|
||||
class PluginUpgradeFromGithubApi(Resource):
|
||||
@api.expect(parser_github_post)
|
||||
@console_ns.expect(parser_github_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -535,7 +535,7 @@ parser_uninstall = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@api.expect(parser_uninstall)
|
||||
@console_ns.expect(parser_uninstall)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -560,7 +560,7 @@ parser_change_post = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@api.expect(parser_change_post)
|
||||
@console_ns.expect(parser_change_post)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -618,7 +618,7 @@ parser_dynamic = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
|
||||
class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@api.expect(parser_dynamic)
|
||||
@console_ns.expect(parser_dynamic)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -655,7 +655,7 @@ parser_change = (
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@api.expect(parser_change)
|
||||
@console_ns.expect(parser_change)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -749,7 +749,7 @@ parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, re
|
|||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@api.expect(parser_exclude)
|
||||
@console_ns.expect(parser_exclude)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
|
|
@ -65,7 +65,7 @@ parser_tool = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-providers")
|
||||
class ToolProviderListApi(Resource):
|
||||
@api.expect(parser_tool)
|
||||
@console_ns.expect(parser_tool)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -113,7 +113,7 @@ parser_delete = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@api.expect(parser_delete)
|
||||
@console_ns.expect(parser_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -140,7 +140,7 @@ parser_add = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
|
||||
class ToolBuiltinProviderAddApi(Resource):
|
||||
@api.expect(parser_add)
|
||||
@console_ns.expect(parser_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -174,7 +174,7 @@ parser_update = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@api.expect(parser_update)
|
||||
@console_ns.expect(parser_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -236,7 +236,7 @@ parser_api_add = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/add")
|
||||
class ToolApiProviderAddApi(Resource):
|
||||
@api.expect(parser_api_add)
|
||||
@console_ns.expect(parser_api_add)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -267,7 +267,7 @@ parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/remote")
|
||||
class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@api.expect(parser_remote)
|
||||
@console_ns.expect(parser_remote)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -292,7 +292,7 @@ parser_tools = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/tools")
|
||||
class ToolApiProviderListToolsApi(Resource):
|
||||
@api.expect(parser_tools)
|
||||
@console_ns.expect(parser_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -328,7 +328,7 @@ parser_api_update = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/update")
|
||||
class ToolApiProviderUpdateApi(Resource):
|
||||
@api.expect(parser_api_update)
|
||||
@console_ns.expect(parser_api_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -362,7 +362,7 @@ parser_api_delete = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/delete")
|
||||
class ToolApiProviderDeleteApi(Resource):
|
||||
@api.expect(parser_api_delete)
|
||||
@console_ns.expect(parser_api_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -386,7 +386,7 @@ parser_get = reqparse.RequestParser().add_argument("provider", type=str, require
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/get")
|
||||
class ToolApiProviderGetApi(Resource):
|
||||
@api.expect(parser_get)
|
||||
@console_ns.expect(parser_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -426,7 +426,7 @@ parser_schema = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/schema")
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@api.expect(parser_schema)
|
||||
@console_ns.expect(parser_schema)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -451,7 +451,7 @@ parser_pre = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/api/test/pre")
|
||||
class ToolApiProviderPreviousTestApi(Resource):
|
||||
@api.expect(parser_pre)
|
||||
@console_ns.expect(parser_pre)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -484,7 +484,7 @@ parser_create = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/create")
|
||||
class ToolWorkflowProviderCreateApi(Resource):
|
||||
@api.expect(parser_create)
|
||||
@console_ns.expect(parser_create)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -525,7 +525,7 @@ parser_workflow_update = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/update")
|
||||
class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@api.expect(parser_workflow_update)
|
||||
@console_ns.expect(parser_workflow_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -560,7 +560,7 @@ parser_workflow_delete = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/delete")
|
||||
class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@api.expect(parser_workflow_delete)
|
||||
@console_ns.expect(parser_workflow_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -588,7 +588,7 @@ parser_wf_get = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/get")
|
||||
class ToolWorkflowProviderGetApi(Resource):
|
||||
@api.expect(parser_wf_get)
|
||||
@console_ns.expect(parser_wf_get)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -624,7 +624,7 @@ parser_wf_tools = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/workflow/tools")
|
||||
class ToolWorkflowProviderListToolApi(Resource):
|
||||
@api.expect(parser_wf_tools)
|
||||
@console_ns.expect(parser_wf_tools)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -813,7 +813,7 @@ parser_default_cred = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
|
||||
class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@api.expect(parser_default_cred)
|
||||
@console_ns.expect(parser_default_cred)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -834,7 +834,7 @@ parser_custom = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
|
||||
class ToolOAuthCustomClient(Resource):
|
||||
@api.expect(parser_custom)
|
||||
@console_ns.expect(parser_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
|
@ -932,7 +932,7 @@ parser_mcp_delete = reqparse.RequestParser().add_argument(
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp")
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@api.expect(parser_mcp)
|
||||
@console_ns.expect(parser_mcp)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -962,7 +962,7 @@ class ToolProviderMCPApi(Resource):
|
|||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
@api.expect(parser_mcp_put)
|
||||
@console_ns.expect(parser_mcp_put)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -1001,7 +1001,7 @@ class ToolProviderMCPApi(Resource):
|
|||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@api.expect(parser_mcp_delete)
|
||||
@console_ns.expect(parser_mcp_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -1024,7 +1024,7 @@ parser_auth = (
|
|||
|
||||
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@api.expect(parser_auth)
|
||||
@console_ns.expect(parser_auth)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -1142,7 +1142,7 @@ parser_cb = (
|
|||
|
||||
@console_ns.route("/mcp/oauth/callback")
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
@api.expect(parser_cb)
|
||||
@console_ns.expect(parser_cb)
|
||||
def get(self):
|
||||
args = parser_cb.parse_args()
|
||||
state_key = args["state"]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
|
@ -539,45 +539,49 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
console_ns.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
console_ns.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
console_ns.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list"
|
||||
)
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
console_ns.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
console_ns.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
console_ns.add_resource(
|
||||
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
|||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.admin import admin_required
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import (
|
||||
|
|
@ -155,7 +155,7 @@ parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, req
|
|||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
class SwitchWorkspaceApi(Resource):
|
||||
@api.expect(parser_switch)
|
||||
@console_ns.expect(parser_switch)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -250,7 +250,7 @@ parser_info = reqparse.RequestParser().add_argument("name", type=str, required=T
|
|||
|
||||
@console_ns.route("/workspaces/info")
|
||||
class WorkspaceInfoApi(Resource):
|
||||
@api.expect(parser_info)
|
||||
@console_ns.expect(parser_info)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ class VariableEntity(BaseModel):
|
|||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,11 @@ class BaseAppGenerator:
|
|||
if value is None:
|
||||
if variable_entity.required:
|
||||
raise ValueError(f"{variable_entity.variable} is required in input form")
|
||||
return value
|
||||
# Use default value and continue validation to ensure type conversion
|
||||
value = variable_entity.default
|
||||
# If default is also None, return None directly
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from enum import StrEnum
|
|||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
|
|
@ -13,6 +13,8 @@ class TracingProviderEnum(StrEnum):
|
|||
OPIK = "opik"
|
||||
WEAVE = "weave"
|
||||
ALIYUN = "aliyun"
|
||||
MLFLOW = "mlflow"
|
||||
DATABRICKS = "databricks"
|
||||
TENCENT = "tencent"
|
||||
|
||||
|
||||
|
|
@ -223,5 +225,47 @@ class TencentConfig(BaseTracingConfig):
|
|||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
|
||||
class MLflowConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for MLflow tracing config.
|
||||
"""
|
||||
|
||||
tracking_uri: str = "http://localhost:5000"
|
||||
experiment_id: str = "0" # Default experiment id in MLflow is 0
|
||||
username: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator("tracking_uri")
|
||||
@classmethod
|
||||
def tracking_uri_validator(cls, v, info: ValidationInfo):
|
||||
if isinstance(v, str) and v.startswith("databricks"):
|
||||
raise ValueError(
|
||||
"Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
|
||||
)
|
||||
return validate_url_with_path(v, "http://localhost:5000")
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
class DatabricksConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Databricks (Databricks-managed MLflow) tracing config.
|
||||
"""
|
||||
|
||||
experiment_id: str
|
||||
host: str
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
personal_access_token: str | None = None
|
||||
|
||||
@field_validator("experiment_id")
|
||||
@classmethod
|
||||
def experiment_id_validator(cls, v, info: ValidationInfo):
|
||||
return validate_integer_id(v)
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,549 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import mlflow
|
||||
from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType
|
||||
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
|
||||
from mlflow.tracing.fluent import start_span_no_context, update_current_trace
|
||||
from mlflow.tracing.provider import detach_span_from_context, set_span_in_context
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def datetime_to_nanoseconds(dt: datetime | None) -> int | None:
|
||||
"""Convert datetime to nanosecond timestamp for MLflow API"""
|
||||
if dt is None:
|
||||
return None
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
class MLflowDataTrace(BaseTraceInstance):
|
||||
def __init__(self, config: MLflowConfig | DatabricksConfig):
|
||||
super().__init__(config)
|
||||
if isinstance(config, DatabricksConfig):
|
||||
self._setup_databricks(config)
|
||||
else:
|
||||
self._setup_mlflow(config)
|
||||
|
||||
# Enable async logging to minimize performance overhead
|
||||
os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] = "true"
|
||||
|
||||
def _setup_databricks(self, config: DatabricksConfig):
|
||||
"""Setup connection to Databricks-managed MLflow instances"""
|
||||
os.environ["DATABRICKS_HOST"] = config.host
|
||||
|
||||
if config.client_id and config.client_secret:
|
||||
# OAuth: https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m?language=Environment
|
||||
os.environ["DATABRICKS_CLIENT_ID"] = config.client_id
|
||||
os.environ["DATABRICKS_CLIENT_SECRET"] = config.client_secret
|
||||
elif config.personal_access_token:
|
||||
# PAT: https://docs.databricks.com/aws/en/dev-tools/auth/pat
|
||||
os.environ["DATABRICKS_TOKEN"] = config.personal_access_token
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either Databricks token (PAT) or client id and secret (OAuth) must be provided"
|
||||
"See https://docs.databricks.com/aws/en/dev-tools/auth/#what-authorization-option-should-i-choose "
|
||||
"for more information about the authorization options."
|
||||
)
|
||||
mlflow.set_tracking_uri("databricks")
|
||||
mlflow.set_experiment(experiment_id=config.experiment_id)
|
||||
|
||||
# Remove trailing slash from host
|
||||
config.host = config.host.rstrip("/")
|
||||
self._project_url = f"{config.host}/ml/experiments/{config.experiment_id}/traces"
|
||||
|
||||
def _setup_mlflow(self, config: MLflowConfig):
|
||||
"""Setup connection to MLflow instances"""
|
||||
mlflow.set_tracking_uri(config.tracking_uri)
|
||||
mlflow.set_experiment(experiment_id=config.experiment_id)
|
||||
|
||||
# Simple auth if provided
|
||||
if config.username and config.password:
|
||||
os.environ["MLFLOW_TRACKING_USERNAME"] = config.username
|
||||
os.environ["MLFLOW_TRACKING_PASSWORD"] = config.password
|
||||
|
||||
self._project_url = f"{config.tracking_uri}/#/experiments/{config.experiment_id}/traces"
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
"""Simple dispatch to trace methods"""
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
elif isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
elif isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
elif isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
elif isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
except Exception:
|
||||
logger.exception("[MLflow] Trace error")
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
"""Create workflow span as root, with node spans as children"""
|
||||
# fields with sys.xyz is added by Dify, they are duplicate to trace_info.metadata
|
||||
raw_inputs = trace_info.workflow_run_inputs or {}
|
||||
workflow_inputs = {k: v for k, v in raw_inputs.items() if not k.startswith("sys.")}
|
||||
|
||||
# Special inputs propagated by system
|
||||
if trace_info.query:
|
||||
workflow_inputs["query"] = trace_info.query
|
||||
|
||||
workflow_span = start_span_no_context(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=workflow_inputs,
|
||||
attributes=trace_info.metadata,
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
# Set reserved fields in trace-level metadata
|
||||
trace_metadata = {}
|
||||
if user_id := trace_info.metadata.get("user_id"):
|
||||
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
|
||||
if session_id := trace_info.conversation_id:
|
||||
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
|
||||
self._set_trace_metadata(workflow_span, trace_metadata)
|
||||
|
||||
try:
|
||||
# Create child spans for workflow nodes
|
||||
for node in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
inputs = None
|
||||
attributes = {
|
||||
"node_id": node.id,
|
||||
"node_type": node.node_type,
|
||||
"status": node.status,
|
||||
"tenant_id": node.tenant_id,
|
||||
"app_id": node.app_id,
|
||||
"app_name": node.title,
|
||||
}
|
||||
|
||||
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
|
||||
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
|
||||
attributes.update(llm_attributes)
|
||||
elif node.node_type == NodeType.HTTP_REQUEST:
|
||||
inputs = node.process_data # contains request URL
|
||||
|
||||
if not inputs:
|
||||
inputs = json.loads(node.inputs) if node.inputs else {}
|
||||
|
||||
node_span = start_span_no_context(
|
||||
name=node.title,
|
||||
span_type=self._get_node_span_type(node.node_type),
|
||||
parent_span=workflow_span,
|
||||
inputs=inputs,
|
||||
attributes=attributes,
|
||||
start_time_ns=datetime_to_nanoseconds(node.created_at),
|
||||
)
|
||||
|
||||
# Handle node errors
|
||||
if node.status != "succeeded":
|
||||
node_span.set_status(SpanStatusCode.ERROR)
|
||||
node_span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="exception",
|
||||
attributes={
|
||||
"exception.message": f"Node failed with status: {node.status}",
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": f"Node failed with status: {node.status}",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# End node span
|
||||
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
|
||||
outputs = json.loads(node.outputs) if node.outputs else {}
|
||||
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
outputs = self._parse_knowledge_retrieval_outputs(outputs)
|
||||
elif node.node_type == NodeType.LLM:
|
||||
outputs = outputs.get("text", outputs)
|
||||
node_span.end(
|
||||
outputs=outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(finished_at),
|
||||
)
|
||||
|
||||
# Handle workflow-level errors
|
||||
if trace_info.error:
|
||||
workflow_span.set_status(SpanStatusCode.ERROR)
|
||||
workflow_span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
workflow_span.end(
|
||||
outputs=trace_info.workflow_run_outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _parse_llm_inputs_and_attributes(self, node: WorkflowNodeExecutionModel) -> tuple[Any, dict]:
|
||||
"""Parse LLM inputs and attributes from LLM workflow node"""
|
||||
if node.process_data is None:
|
||||
return {}, {}
|
||||
|
||||
try:
|
||||
data = json.loads(node.process_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}, {}
|
||||
|
||||
inputs = self._parse_prompts(data.get("prompts"))
|
||||
attributes = {
|
||||
"model_name": data.get("model_name"),
|
||||
"model_provider": data.get("model_provider"),
|
||||
"finish_reason": data.get("finish_reason"),
|
||||
}
|
||||
|
||||
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
|
||||
attributes[SpanAttributeKey.MESSAGE_FORMAT] = "dify"
|
||||
|
||||
if usage := data.get("usage"):
|
||||
# Set reserved token usage attributes
|
||||
attributes[SpanAttributeKey.CHAT_USAGE] = {
|
||||
TokenUsageKey.INPUT_TOKENS: usage.get("prompt_tokens", 0),
|
||||
TokenUsageKey.OUTPUT_TOKENS: usage.get("completion_tokens", 0),
|
||||
TokenUsageKey.TOTAL_TOKENS: usage.get("total_tokens", 0),
|
||||
}
|
||||
# Store raw usage data as well as it includes more data like price
|
||||
attributes["usage"] = usage
|
||||
|
||||
return inputs, attributes
|
||||
|
||||
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
|
||||
"""Parse KR outputs and attributes from KR workflow node"""
|
||||
retrieved = outputs.get("result", [])
|
||||
|
||||
if not retrieved or not isinstance(retrieved, list):
|
||||
return outputs
|
||||
|
||||
documents = []
|
||||
for item in retrieved:
|
||||
documents.append(Document(page_content=item.get("content", ""), metadata=item.get("metadata", {})))
|
||||
return documents
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
"""Create span for CHATBOT message processing"""
|
||||
if not trace_info.message_data:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
if message_file_data := trace_info.message_file_data:
|
||||
base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
file_list.append(f"{base_url}/{message_file_data.url}")
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
span_type=SpanType.LLM,
|
||||
inputs=self._parse_prompts(trace_info.inputs), # type: ignore[arg-type]
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"model_provider": trace_info.message_data.model_provider,
|
||||
"model_id": trace_info.message_data.model_id,
|
||||
"conversation_mode": trace_info.conversation_mode,
|
||||
"file_list": file_list, # type: ignore[dict-item]
|
||||
"total_price": trace_info.message_data.total_price,
|
||||
**trace_info.metadata,
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
if hasattr(SpanAttributeKey, "MESSAGE_FORMAT"):
|
||||
span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "dify")
|
||||
|
||||
# Set token usage
|
||||
span.set_attribute(
|
||||
SpanAttributeKey.CHAT_USAGE,
|
||||
{
|
||||
TokenUsageKey.INPUT_TOKENS: trace_info.message_tokens or 0,
|
||||
TokenUsageKey.OUTPUT_TOKENS: trace_info.answer_tokens or 0,
|
||||
TokenUsageKey.TOTAL_TOKENS: trace_info.total_tokens or 0,
|
||||
},
|
||||
)
|
||||
|
||||
# Set reserved fields in trace-level metadata
|
||||
trace_metadata = {}
|
||||
if user_id := self._get_message_user_id(trace_info.metadata):
|
||||
trace_metadata[TraceMetadataKey.TRACE_USER] = user_id
|
||||
if session_id := trace_info.metadata.get("conversation_id"):
|
||||
trace_metadata[TraceMetadataKey.TRACE_SESSION] = session_id
|
||||
self._set_trace_metadata(span, trace_metadata)
|
||||
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs=trace_info.message_data.answer,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def _get_message_user_id(self, metadata: dict) -> str | None:
|
||||
if (end_user_id := metadata.get("from_end_user_id")) and (
|
||||
end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||
):
|
||||
return end_user_data.session_id
|
||||
|
||||
return metadata.get("from_account_id") # type: ignore[return-value]
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span = start_span_no_context(
|
||||
name=trace_info.tool_name,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.tool_inputs, # type: ignore[arg-type]
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
"tool_config": trace_info.tool_config, # type: ignore[dict-item]
|
||||
"tool_parameters": trace_info.tool_parameters, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
|
||||
# Handle tool errors
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs=trace_info.tool_outputs,
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs or {},
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
|
||||
span.end(
|
||||
outputs={
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
},
|
||||
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
|
||||
)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
span_type=SpanType.RETRIEVER,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"metadata": trace_info.metadata, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
span.end(outputs={"documents": trace_info.documents}, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
span_type=SpanType.TOOL,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={
|
||||
"message_id": trace_info.message_id, # type: ignore[dict-item]
|
||||
"model_provider": trace_info.model_provider, # type: ignore[dict-item]
|
||||
"model_id": trace_info.model_id, # type: ignore[dict-item]
|
||||
"total_tokens": trace_info.total_tokens or 0, # type: ignore[dict-item]
|
||||
},
|
||||
start_time_ns=datetime_to_nanoseconds(start_time),
|
||||
)
|
||||
|
||||
if trace_info.error:
|
||||
span.set_status(SpanStatusCode.ERROR)
|
||||
span.add_event(
|
||||
SpanEvent( # type: ignore[abstract]
|
||||
name="error",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
span.end(outputs=trace_info.suggested_question, end_time_ns=datetime_to_nanoseconds(end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
span = start_span_no_context(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
span_type=SpanType.CHAIN,
|
||||
inputs=trace_info.inputs,
|
||||
attributes={"message_id": trace_info.message_id}, # type: ignore[dict-item]
|
||||
start_time_ns=datetime_to_nanoseconds(trace_info.start_time),
|
||||
)
|
||||
span.end(outputs=trace_info.outputs, end_time_ns=datetime_to_nanoseconds(trace_info.end_time))
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowNodeExecutionModel.created_at)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
|
||||
def _get_node_span_type(self, node_type: str) -> str:
|
||||
"""Map Dify node types to MLflow span types"""
|
||||
node_type_mapping = {
|
||||
NodeType.LLM: SpanType.LLM,
|
||||
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
|
||||
NodeType.TOOL: SpanType.TOOL,
|
||||
NodeType.CODE: SpanType.TOOL,
|
||||
NodeType.HTTP_REQUEST: SpanType.TOOL,
|
||||
NodeType.AGENT: SpanType.AGENT,
|
||||
}
|
||||
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
|
||||
|
||||
def _set_trace_metadata(self, span: Span, metadata: dict):
|
||||
token = None
|
||||
try:
|
||||
# NB: Set span in context such that we can use update_current_trace() API
|
||||
token = set_span_in_context(span)
|
||||
update_current_trace(metadata=metadata)
|
||||
finally:
|
||||
if token:
|
||||
detach_span_from_context(token)
|
||||
|
||||
def _parse_prompts(self, prompts):
|
||||
"""Postprocess prompts format to be standard chat messages"""
|
||||
if isinstance(prompts, str):
|
||||
return prompts
|
||||
elif isinstance(prompts, dict):
|
||||
return self._parse_single_message(prompts)
|
||||
elif isinstance(prompts, list):
|
||||
messages = [self._parse_single_message(item) for item in prompts]
|
||||
messages = self._resolve_tool_call_ids(messages)
|
||||
return messages
|
||||
return prompts # Fallback to original format
|
||||
|
||||
def _parse_single_message(self, item: dict):
|
||||
"""Postprocess single message format to be standard chat message"""
|
||||
role = item.get("role", "user")
|
||||
msg = {"role": role, "content": item.get("text", "")}
|
||||
|
||||
if (
|
||||
(tool_calls := item.get("tool_calls"))
|
||||
# Tool message does not contain tool calls normally
|
||||
and role != "tool"
|
||||
):
|
||||
msg["tool_calls"] = tool_calls
|
||||
|
||||
if files := item.get("files"):
|
||||
msg["files"] = files
|
||||
|
||||
return msg
|
||||
|
||||
def _resolve_tool_call_ids(self, messages: list[dict]):
|
||||
"""
|
||||
The tool call message from Dify does not contain tool call ids, which is not
|
||||
ideal for debugging. This method resolves the tool call ids by matching the
|
||||
tool call name and parameters with the tool instruction messages.
|
||||
"""
|
||||
tool_call_ids = []
|
||||
for msg in messages:
|
||||
if tool_calls := msg.get("tool_calls"):
|
||||
tool_call_ids = [t["id"] for t in tool_calls]
|
||||
if msg["role"] == "tool":
|
||||
# Get the tool call id in the order of the tool call messages
|
||||
# assuming Dify runs tools sequentially
|
||||
if tool_call_ids:
|
||||
msg["tool_call_id"] = tool_call_ids.pop(0)
|
||||
return messages
|
||||
|
||||
def api_check(self):
|
||||
"""Simple connection test"""
|
||||
try:
|
||||
mlflow.search_experiments(max_results=1)
|
||||
return True
|
||||
except Exception as e:
|
||||
raise ValueError(f"MLflow connection failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
return self._project_url
|
||||
|
|
@ -120,6 +120,26 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
|||
"other_keys": ["endpoint", "app_name"],
|
||||
"trace_instance": AliyunDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.MLFLOW:
|
||||
from core.ops.entities.config_entity import MLflowConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": MLflowConfig,
|
||||
"secret_keys": ["password"],
|
||||
"other_keys": ["tracking_uri", "experiment_id", "username"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
case TracingProviderEnum.DATABRICKS:
|
||||
from core.ops.entities.config_entity import DatabricksConfig
|
||||
from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
|
||||
|
||||
return {
|
||||
"config_class": DatabricksConfig,
|
||||
"secret_keys": ["personal_access_token", "client_secret"],
|
||||
"other_keys": ["host", "client_id", "experiment_id"],
|
||||
"trace_instance": MLflowDataTrace,
|
||||
}
|
||||
|
||||
case TracingProviderEnum.TENCENT:
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
|
|
|
|||
|
|
@ -147,3 +147,14 @@ def validate_project_name(project: str, default_name: str) -> str:
|
|||
return default_name
|
||||
|
||||
return project.strip()
|
||||
|
||||
|
||||
def validate_integer_id(id_str: str) -> str:
|
||||
"""
|
||||
Validate and normalize integer ID
|
||||
"""
|
||||
id_str = id_str.strip()
|
||||
if not id_str.isdigit():
|
||||
raise ValueError("ID must be a valid integer")
|
||||
|
||||
return id_str
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import base64
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -89,8 +90,8 @@ class CacheEmbedding(Embeddings):
|
|||
model_name=self._model_instance.model,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
)
|
||||
embedding_cache.set_embedding(n_embedding)
|
||||
db.session.add(embedding_cache)
|
||||
cache_embeddings.append(hash)
|
||||
db.session.commit()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,19 @@ class RedisSubscriptionBase(Subscription):
|
|||
pubsub = self._pubsub
|
||||
assert pubsub is not None, "PubSub should not be None while starting listening."
|
||||
while not self._closed.is_set():
|
||||
raw_message = self._get_message()
|
||||
try:
|
||||
raw_message = self._get_message()
|
||||
except Exception as e:
|
||||
# Log the exception and exit the listener thread gracefully
|
||||
# This handles Redis connection errors and other exceptions
|
||||
_logger.error(
|
||||
"Error getting message from Redis %s subscription, topic=%s: %s",
|
||||
self._get_subscription_type(),
|
||||
self._topic,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
|
@ -98,10 +110,20 @@ class RedisSubscriptionBase(Subscription):
|
|||
self._enqueue_message(payload_bytes)
|
||||
|
||||
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
|
||||
self._unsubscribe()
|
||||
pubsub.close()
|
||||
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
|
||||
self._pubsub = None
|
||||
try:
|
||||
self._unsubscribe()
|
||||
pubsub.close()
|
||||
_logger.debug("%s PubSub closed for topic %s", self._get_subscription_type().title(), self._topic)
|
||||
except Exception as e:
|
||||
_logger.error(
|
||||
"Error during cleanup of Redis %s subscription, topic=%s: %s",
|
||||
self._get_subscription_type(),
|
||||
self._topic,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
def _enqueue_message(self, payload: bytes) -> None:
|
||||
"""Enqueue a message to the internal queue with dropping behavior."""
|
||||
|
|
|
|||
|
|
@ -307,7 +307,7 @@ class Dataset(Base):
|
|||
return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
|
||||
|
||||
|
||||
class DatasetProcessRule(Base):
|
||||
class DatasetProcessRule(Base): # bug
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
|
|
@ -1004,7 +1004,7 @@ class DatasetKeywordTable(TypeBase):
|
|||
return None
|
||||
|
||||
|
||||
class Embedding(Base):
|
||||
class Embedding(TypeBase):
|
||||
__tablename__ = "embeddings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="embedding_pkey"),
|
||||
|
|
@ -1012,12 +1012,16 @@ class Embedding(Base):
|
|||
sa.Index("created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
|
||||
model_name = mapped_column(String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'"))
|
||||
hash = mapped_column(String(64), nullable=False)
|
||||
embedding = mapped_column(BinaryData, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
|
||||
model_name: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'")
|
||||
)
|
||||
hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
embedding: Mapped[bytes] = mapped_column(BinaryData, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("''"))
|
||||
|
||||
def set_embedding(self, embedding_data: list[float]):
|
||||
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
|
@ -1026,19 +1030,21 @@ class Embedding(Base):
|
|||
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
|
||||
|
||||
|
||||
class DatasetCollectionBinding(Base):
|
||||
class DatasetCollectionBinding(TypeBase):
|
||||
__tablename__ = "dataset_collection_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
|
||||
sa.Index("provider_model_name_idx", "provider_name", "model_name"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
type = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
collection_name = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False)
|
||||
collection_name: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class TidbAuthBinding(Base):
|
||||
|
|
@ -1176,7 +1182,7 @@ class ExternalKnowledgeBindings(TypeBase):
|
|||
)
|
||||
|
||||
|
||||
class DatasetAutoDisableLog(Base):
|
||||
class DatasetAutoDisableLog(TypeBase):
|
||||
__tablename__ = "dataset_auto_disable_logs"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
|
||||
|
|
@ -1185,12 +1191,14 @@ class DatasetAutoDisableLog(Base):
|
|||
sa.Index("dataset_auto_disable_log_created_atx", "created_at"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class RateLimitLog(TypeBase):
|
||||
|
|
@ -1210,7 +1218,7 @@ class RateLimitLog(TypeBase):
|
|||
)
|
||||
|
||||
|
||||
class DatasetMetadata(Base):
|
||||
class DatasetMetadata(TypeBase):
|
||||
__tablename__ = "dataset_metadatas"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
|
||||
|
|
@ -1218,20 +1226,26 @@ class DatasetMetadata(Base):
|
|||
sa.Index("dataset_metadata_dataset_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
updated_by: Mapped[str] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
|
||||
class DatasetMetadataBinding(Base):
|
||||
class DatasetMetadataBinding(TypeBase):
|
||||
__tablename__ = "dataset_metadata_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
|
||||
|
|
@ -1241,13 +1255,15 @@ class DatasetMetadataBinding(Base):
|
|||
sa.Index("dataset_metadata_binding_document_idx", "document_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
metadata_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
metadata_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
|
||||
class PipelineBuiltInTemplate(TypeBase):
|
||||
|
|
@ -1315,22 +1331,30 @@ class PipelineCustomizedTemplate(TypeBase):
|
|||
return ""
|
||||
|
||||
|
||||
class Pipeline(Base): # type: ignore[name-defined]
|
||||
class Pipeline(TypeBase):
|
||||
__tablename__ = "pipelines"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="pipeline_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
name = mapped_column(sa.String(255), nullable=False)
|
||||
description = mapped_column(LongText, nullable=False, default=sa.text("''"))
|
||||
workflow_id = mapped_column(StringUUID, nullable=True)
|
||||
is_public = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
is_published = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(LongText, nullable=False, default=sa.text("''"))
|
||||
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
is_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
is_published: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
|
||||
)
|
||||
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
|||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.file import helpers as file_helpers
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
|
@ -533,7 +533,7 @@ class AppModelConfig(Base):
|
|||
return self
|
||||
|
||||
|
||||
class RecommendedApp(Base):
|
||||
class RecommendedApp(Base): # bug
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||
|
|
@ -594,7 +594,7 @@ class InstalledApp(TypeBase):
|
|||
return tenant
|
||||
|
||||
|
||||
class OAuthProviderApp(Base):
|
||||
class OAuthProviderApp(TypeBase):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
Only for Dify Cloud.
|
||||
|
|
@ -606,18 +606,21 @@ class OAuthProviderApp(Base):
|
|||
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
app_icon = mapped_column(String(255), nullable=False)
|
||||
app_label = mapped_column(sa.JSON, nullable=False, default="{}")
|
||||
client_id = mapped_column(String(255), nullable=False)
|
||||
client_secret = mapped_column(String(255), nullable=False)
|
||||
redirect_uris = mapped_column(sa.JSON, nullable=False, default="[]")
|
||||
scope = mapped_column(
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
|
||||
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
|
||||
scope: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
|
||||
default="read:name read:email read:avatar read:interface_language read:timezone",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
|
|
@ -1291,7 +1294,7 @@ class Message(Base):
|
|||
)
|
||||
|
||||
|
||||
class MessageFeedback(Base):
|
||||
class MessageFeedback(TypeBase):
|
||||
__tablename__ = "message_feedbacks"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||
|
|
@ -1300,18 +1303,24 @@ class MessageFeedback(Base):
|
|||
sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content: Mapped[str | None] = mapped_column(LongText)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -1335,7 +1344,7 @@ class MessageFeedback(Base):
|
|||
}
|
||||
|
||||
|
||||
class MessageFile(Base):
|
||||
class MessageFile(TypeBase):
|
||||
__tablename__ = "message_files"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="message_file_pkey"),
|
||||
|
|
@ -1343,37 +1352,18 @@ class MessageFile(Base):
|
|||
sa.Index("message_file_created_by_idx", "created_by"),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
message_id: str,
|
||||
type: FileType,
|
||||
transfer_method: FileTransferMethod,
|
||||
url: str | None = None,
|
||||
belongs_to: Literal["user", "assistant"] | None = None,
|
||||
upload_file_id: str | None = None,
|
||||
created_by_role: CreatorUserRole,
|
||||
created_by: str,
|
||||
):
|
||||
self.message_id = message_id
|
||||
self.type = type
|
||||
self.transfer_method = transfer_method
|
||||
self.url = url
|
||||
self.belongs_to = belongs_to
|
||||
self.upload_file_id = upload_file_id
|
||||
self.created_by_role = created_by_role.value
|
||||
self.created_by = created_by
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
belongs_to: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class MessageAnnotation(Base):
|
||||
|
|
@ -1447,22 +1437,28 @@ class AppAnnotationHitHistory(Base):
|
|||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(Base):
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
__tablename__ = "app_annotation_settings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||
sa.Index("app_annotation_settings_app_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = mapped_column(StringUUID, nullable=False)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
score_threshold: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
|
||||
collection_binding_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -1477,22 +1473,28 @@ class AppAnnotationSetting(Base):
|
|||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(Base):
|
||||
class OperationLog(TypeBase):
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||
sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
action: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
content = mapped_column(sa.JSON)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
content: Mapped[Any] = mapped_column(sa.JSON)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
created_ip: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1637,7 +1639,7 @@ class Site(Base):
|
|||
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
|
||||
|
||||
|
||||
class ApiToken(Base):
|
||||
class ApiToken(Base): # bug: this uses setattr so idk the field.
|
||||
__tablename__ = "api_tokens"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||
|
|
@ -1897,34 +1899,36 @@ class MessageAgentThought(Base):
|
|||
return {}
|
||||
|
||||
|
||||
class DatasetRetrieverResource(Base):
|
||||
class DatasetRetrieverResource(TypeBase):
|
||||
__tablename__ = "dataset_retriever_resources"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||
sa.Index("dataset_retriever_resource_message_id_idx", "message_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_name = mapped_column(LongText, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=True)
|
||||
document_name = mapped_column(LongText, nullable=False)
|
||||
data_source_type = mapped_column(LongText, nullable=True)
|
||||
segment_id = mapped_column(StringUUID, nullable=True)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
dataset_name: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
document_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
document_name: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
data_source_type: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
segment_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
score: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
|
||||
content = mapped_column(LongText, nullable=False)
|
||||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
hit_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
word_count: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
segment_position: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
index_node_hash = mapped_column(LongText, nullable=True)
|
||||
retriever_from = mapped_column(LongText, nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())
|
||||
index_node_hash: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
retriever_from: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
class Tag(TypeBase):
|
||||
__tablename__ = "tags"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||
|
|
@ -1934,12 +1938,14 @@ class Tag(Base):
|
|||
|
||||
TAG_TYPE_LIST = ["knowledge", "app"]
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=True)
|
||||
type = mapped_column(String(16), nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
type: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class TagBinding(TypeBase):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base, TypeBase
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .types import LongText, StringUUID
|
||||
|
||||
|
|
@ -262,7 +262,7 @@ class ProviderModelSetting(TypeBase):
|
|||
)
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
class LoadBalancingModelConfig(TypeBase):
|
||||
"""
|
||||
Configurations for load balancing models.
|
||||
"""
|
||||
|
|
@ -273,23 +273,25 @@ class LoadBalancingModelConfig(Base):
|
|||
sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class ProviderCredential(Base):
|
||||
class ProviderCredential(TypeBase):
|
||||
"""
|
||||
Provider credential - stores multiple named credentials for each provider
|
||||
"""
|
||||
|
|
@ -300,18 +302,20 @@ class ProviderCredential(Base):
|
|||
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
|
||||
class ProviderModelCredential(Base):
|
||||
class ProviderModelCredential(TypeBase):
|
||||
"""
|
||||
Provider model credential - stores multiple named credentials for each provider model
|
||||
"""
|
||||
|
|
@ -328,14 +332,16 @@ class ProviderModelCredential(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from core.trigger.utils.endpoint import generate_plugin_trigger_endpoint_url, ge
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import Base, TypeBase
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
|
||||
from .model import Account
|
||||
|
|
@ -129,27 +129,30 @@ class TriggerOAuthSystemClient(TypeBase):
|
|||
|
||||
|
||||
# tenant level trigger oauth client params (client_id, client_secret, etc.)
|
||||
class TriggerOAuthTenantClient(Base):
|
||||
class TriggerOAuthTenantClient(TypeBase):
|
||||
__tablename__ = "trigger_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(LongText, nullable=False, default="{}")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
server_onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -157,7 +160,7 @@ class TriggerOAuthTenantClient(Base):
|
|||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(Base):
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
"""
|
||||
Workflow Trigger Log
|
||||
|
||||
|
|
@ -199,7 +202,7 @@ class WorkflowTriggerLog(Base):
|
|||
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
|
@ -211,24 +214,21 @@ class WorkflowTriggerLog(Base):
|
|||
inputs: Mapped[str] = mapped_column(LongText, nullable=False) # Just inputs for easy viewing
|
||||
outputs: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
status: Mapped[str] = mapped_column(
|
||||
EnumText(WorkflowTriggerStatus, length=50), nullable=False, default=WorkflowTriggerStatus.PENDING
|
||||
)
|
||||
status: Mapped[str] = mapped_column(EnumText(WorkflowTriggerStatus, length=50), nullable=False)
|
||||
error: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
|
||||
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
|
||||
elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
|
||||
total_tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
triggered_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None)
|
||||
total_tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
triggered_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class _InvalidGraphDefinitionError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class Workflow(Base):
|
||||
class Workflow(Base): # bug
|
||||
"""
|
||||
Workflow, for `Workflow App` and `Chat App workflow mode`.
|
||||
|
||||
|
|
@ -869,16 +869,20 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
|||
@property
|
||||
def created_by_account(self):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||
if created_by_role == CreatorUserRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == self.created_by)
|
||||
return db.session.scalar(stmt)
|
||||
return None
|
||||
|
||||
@property
|
||||
def created_by_end_user(self):
|
||||
from .model import EndUser
|
||||
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
# TODO(-LAN-): Avoid using db.session.get() here.
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
if created_by_role == CreatorUserRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == self.created_by)
|
||||
return db.session.scalar(stmt)
|
||||
return None
|
||||
|
||||
@property
|
||||
def inputs_dict(self):
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ dependencies = [
|
|||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"markdown~=3.5.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.8.72",
|
||||
|
|
@ -202,7 +203,7 @@ vdb = [
|
|||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_tea_openapi~=0.3.9",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.7.16",
|
||||
"clickhouse-connect~=0.10.0",
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
|
|
|
|||
|
|
@ -113,6 +113,8 @@ class AsyncWorkflowService:
|
|||
trigger_data.trigger_metadata.model_dump_json() if trigger_data.trigger_metadata else "{}"
|
||||
),
|
||||
trigger_type=trigger_data.trigger_type,
|
||||
workflow_run_id=None,
|
||||
outputs=None,
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||
status=WorkflowTriggerStatus.PENDING,
|
||||
|
|
@ -120,6 +122,10 @@ class AsyncWorkflowService:
|
|||
retry_count=0,
|
||||
created_by_role=created_by_role,
|
||||
created_by=created_by,
|
||||
celery_task_id=None,
|
||||
error=None,
|
||||
elapsed_time=None,
|
||||
total_tokens=None,
|
||||
)
|
||||
|
||||
trigger_log = trigger_log_repo.create(trigger_log)
|
||||
|
|
|
|||
|
|
@ -164,6 +164,7 @@ class MessageService:
|
|||
elif not rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
assert rating is not None
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
|
|
|
|||
|
|
@ -113,6 +113,24 @@ class OpsService:
|
|||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://console.cloud.tencent.com/apm"})
|
||||
|
||||
if tracing_provider == "mlflow" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "http://localhost:5000/"})
|
||||
|
||||
if tracing_provider == "databricks" and (
|
||||
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
|
||||
):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
|
||||
new_decrypt_tracing_config.update({"project_url": project_url})
|
||||
except Exception:
|
||||
new_decrypt_tracing_config.update({"project_url": "https://www.databricks.com/"})
|
||||
|
||||
trace_config_data.tracing_config = new_decrypt_tracing_config
|
||||
return trace_config_data.to_dict()
|
||||
|
||||
|
|
@ -155,7 +173,7 @@ class OpsService:
|
|||
project_url = f"{tracing_config.get('host')}/project/{project_key}"
|
||||
except Exception:
|
||||
project_url = None
|
||||
elif tracing_provider in ("langsmith", "opik", "tencent"):
|
||||
elif tracing_provider in ("langsmith", "opik", "mlflow", "databricks", "tencent"):
|
||||
try:
|
||||
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -580,13 +580,14 @@ class RagPipelineDslService:
|
|||
raise ValueError("Current tenant is not set")
|
||||
|
||||
# Create new app
|
||||
pipeline = Pipeline()
|
||||
pipeline = Pipeline(
|
||||
tenant_id=account.current_tenant_id,
|
||||
name=pipeline_data.get("name", ""),
|
||||
description=pipeline_data.get("description", ""),
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
pipeline.id = str(uuid4())
|
||||
pipeline.tenant_id = account.current_tenant_id
|
||||
pipeline.name = pipeline_data.get("name", "")
|
||||
pipeline.description = pipeline_data.get("description", "")
|
||||
pipeline.created_by = account.id
|
||||
pipeline.updated_by = account.id
|
||||
|
||||
self._session.add(pipeline)
|
||||
self._session.commit()
|
||||
|
|
|
|||
|
|
@ -198,15 +198,16 @@ class RagPipelineTransformService:
|
|||
graph = workflow_data.get("graph", {})
|
||||
|
||||
# Create new app
|
||||
pipeline = Pipeline()
|
||||
pipeline = Pipeline(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name=pipeline_data.get("name", ""),
|
||||
description=pipeline_data.get("description", ""),
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
is_published=True,
|
||||
is_public=True,
|
||||
)
|
||||
pipeline.id = str(uuid4())
|
||||
pipeline.tenant_id = current_user.current_tenant_id
|
||||
pipeline.name = pipeline_data.get("name", "")
|
||||
pipeline.description = pipeline_data.get("description", "")
|
||||
pipeline.created_by = current_user.id
|
||||
pipeline.updated_by = current_user.id
|
||||
pipeline.is_published = True
|
||||
pipeline.is_public = True
|
||||
|
||||
db.session.add(pipeline)
|
||||
db.session.flush()
|
||||
|
|
|
|||
|
|
@ -79,12 +79,12 @@ class TagService:
|
|||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = Tag(
|
||||
id=str(uuid.uuid4()),
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
tag.id = str(uuid.uuid4())
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
return tag
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ class TriggerProviderService:
|
|||
oauth_params = encrypter.decrypt(dict(tenant_client.oauth_params))
|
||||
return oauth_params
|
||||
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
|
||||
if not is_verified:
|
||||
return None
|
||||
|
||||
|
|
@ -499,7 +499,8 @@ class TriggerProviderService:
|
|||
"""
|
||||
Check if system OAuth client exists for a trigger provider.
|
||||
"""
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id=tenant_id, provider_id=provider_id)
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
|
||||
if not is_verified:
|
||||
return False
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
|
|
|
|||
|
|
@ -218,6 +218,8 @@ def _record_trigger_failure_log(
|
|||
finished_at=now,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
outputs=None,
|
||||
celery_task_id=None,
|
||||
)
|
||||
session.add(trigger_log)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -852,6 +852,7 @@ class TestAgentService:
|
|||
# Add files to message
|
||||
from models.model import MessageFile
|
||||
|
||||
assert message.from_account_id is not None
|
||||
message_file1 = MessageFile(
|
||||
message_id=message.id,
|
||||
type=FileType.IMAGE,
|
||||
|
|
|
|||
|
|
@ -860,22 +860,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -919,22 +921,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1020,22 +1024,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1080,22 +1086,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1151,22 +1159,25 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@ -1211,22 +1222,24 @@ class TestAnnotationService:
|
|||
from models.model import AppAnnotationSetting
|
||||
|
||||
# Create a collection binding first
|
||||
collection_binding = DatasetCollectionBinding()
|
||||
collection_binding.id = fake.uuid4()
|
||||
collection_binding.provider_name = "openai"
|
||||
collection_binding.model_name = "text-embedding-ada-002"
|
||||
collection_binding.type = "annotation"
|
||||
collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}"
|
||||
collection_binding = DatasetCollectionBinding(
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
type="annotation",
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting()
|
||||
annotation_setting.app_id = app.id
|
||||
annotation_setting.score_threshold = 0.8
|
||||
annotation_setting.collection_binding_id = collection_binding.id
|
||||
annotation_setting.created_user_id = account.id
|
||||
annotation_setting.updated_user_id = account.id
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
app_id=app.id,
|
||||
score_threshold=0.8,
|
||||
collection_binding_id=collection_binding.id,
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -502,11 +502,11 @@ class TestAddDocumentToIndexTask:
|
|||
auto_disable_logs = []
|
||||
for _ in range(2):
|
||||
log_entry = DatasetAutoDisableLog(
|
||||
id=fake.uuid4(),
|
||||
tenant_id=document.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
)
|
||||
log_entry.id = str(fake.uuid4())
|
||||
db.session.add(log_entry)
|
||||
auto_disable_logs.append(log_entry)
|
||||
|
||||
|
|
|
|||
|
|
@ -384,24 +384,24 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Create dataset metadata and bindings
|
||||
metadata = DatasetMetadata(
|
||||
id=str(uuid.uuid4()),
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name="test_metadata",
|
||||
type="string",
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
metadata.id = str(uuid.uuid4())
|
||||
metadata.created_at = datetime.now()
|
||||
|
||||
binding = DatasetMetadataBinding(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
metadata_id=metadata.id,
|
||||
document_id=documents[0].id, # Use first document as example
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
binding.id = str(uuid.uuid4())
|
||||
binding.created_at = datetime.now()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
@ -697,26 +697,26 @@ class TestCleanDatasetTask:
|
|||
|
||||
for i in range(10): # Create 10 metadata items
|
||||
metadata = DatasetMetadata(
|
||||
id=str(uuid.uuid4()),
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name=f"test_metadata_{i}",
|
||||
type="string",
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
metadata.id = str(uuid.uuid4())
|
||||
metadata.created_at = datetime.now()
|
||||
metadata_items.append(metadata)
|
||||
|
||||
# Create binding for each metadata item
|
||||
binding = DatasetMetadataBinding(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
metadata_id=metadata.id,
|
||||
document_id=documents[i % len(documents)].id,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
binding.id = str(uuid.uuid4())
|
||||
binding.created_at = datetime.now()
|
||||
bindings.append(binding)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -966,14 +966,15 @@ class TestCleanDatasetTask:
|
|||
|
||||
# Create metadata with special characters
|
||||
special_metadata = DatasetMetadata(
|
||||
id=str(uuid.uuid4()),
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=tenant.id,
|
||||
name=f"metadata_{special_content}",
|
||||
type="string",
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
special_metadata.id = str(uuid.uuid4())
|
||||
special_metadata.created_at = datetime.now()
|
||||
|
||||
db.session.add(special_metadata)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -112,13 +112,13 @@ class TestRagPipelineRunTasks:
|
|||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant.id,
|
||||
workflow_id=workflow.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
created_by=account.id,
|
||||
)
|
||||
pipeline.id = str(uuid.uuid4())
|
||||
db.session.add(pipeline)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -50,3 +50,218 @@ def test_validate_input_with_none_for_required_variable():
|
|||
)
|
||||
|
||||
assert str(exc_info.value) == "test_var is required in input form"
|
||||
|
||||
|
||||
def test_validate_inputs_with_default_value():
|
||||
"""Test that default values are used when input is None for optional variables"""
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
# Test with string default value for TEXT_INPUT
|
||||
var_string = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
default="default_string",
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_string,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == "default_string"
|
||||
|
||||
# Test with string default value for PARAGRAPH
|
||||
var_paragraph = VariableEntity(
|
||||
variable="test_paragraph",
|
||||
label="test_paragraph",
|
||||
type=VariableEntityType.PARAGRAPH,
|
||||
required=False,
|
||||
default="default paragraph text",
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_paragraph,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == "default paragraph text"
|
||||
|
||||
# Test with SELECT default value
|
||||
var_select = VariableEntity(
|
||||
variable="test_select",
|
||||
label="test_select",
|
||||
type=VariableEntityType.SELECT,
|
||||
required=False,
|
||||
default="option1",
|
||||
options=["option1", "option2", "option3"],
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_select,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == "option1"
|
||||
|
||||
# Test with number default value (int)
|
||||
var_number_int = VariableEntity(
|
||||
variable="test_number_int",
|
||||
label="test_number_int",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=False,
|
||||
default=42,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_number_int,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == 42
|
||||
|
||||
# Test with number default value (float)
|
||||
var_number_float = VariableEntity(
|
||||
variable="test_number_float",
|
||||
label="test_number_float",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=False,
|
||||
default=3.14,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_number_float,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == 3.14
|
||||
|
||||
# Test with number default value as string (frontend sends as string)
|
||||
var_number_string = VariableEntity(
|
||||
variable="test_number_string",
|
||||
label="test_number_string",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=False,
|
||||
default="123",
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_number_string,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == 123
|
||||
assert isinstance(result, int)
|
||||
|
||||
# Test with float number default value as string
|
||||
var_number_float_string = VariableEntity(
|
||||
variable="test_number_float_string",
|
||||
label="test_number_float_string",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=False,
|
||||
default="45.67",
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_number_float_string,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == 45.67
|
||||
assert isinstance(result, float)
|
||||
|
||||
# Test with CHECKBOX default value (bool)
|
||||
var_checkbox_true = VariableEntity(
|
||||
variable="test_checkbox_true",
|
||||
label="test_checkbox_true",
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
required=False,
|
||||
default=True,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_checkbox_true,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
var_checkbox_false = VariableEntity(
|
||||
variable="test_checkbox_false",
|
||||
label="test_checkbox_false",
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
required=False,
|
||||
default=False,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_checkbox_false,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
# Test with None as explicit default value
|
||||
var_none_default = VariableEntity(
|
||||
variable="test_none",
|
||||
label="test_none",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_none_default,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
# Test that actual input value takes precedence over default
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_string,
|
||||
value="actual_value",
|
||||
)
|
||||
|
||||
assert result == "actual_value"
|
||||
|
||||
# Test that actual number input takes precedence over default
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_number_int,
|
||||
value=999,
|
||||
)
|
||||
|
||||
assert result == 999
|
||||
|
||||
# Test with FILE default value (dict format from frontend)
|
||||
var_file = VariableEntity(
|
||||
variable="test_file",
|
||||
label="test_file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
default={"id": "file123", "name": "default.pdf"},
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == {"id": "file123", "name": "default.pdf"}
|
||||
|
||||
# Test with FILE_LIST default value (list of dicts)
|
||||
var_file_list = VariableEntity(
|
||||
variable="test_file_list",
|
||||
label="test_file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
default=[{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}],
|
||||
)
|
||||
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var_file_list,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert result == [{"id": "file1", "name": "doc1.pdf"}, {"id": "file2", "name": "doc2.pdf"}]
|
||||
|
|
|
|||
|
|
@ -39,9 +39,9 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
ps.id = "id"
|
||||
|
||||
provider_model_settings = [ps]
|
||||
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -51,7 +51,6 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -61,6 +60,8 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
|||
enabled=True,
|
||||
),
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
@ -101,7 +102,6 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
|||
provider_model_settings = [ps]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -111,6 +111,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
@ -148,7 +149,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
provider_model_settings = [ps]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -158,7 +158,6 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
|
|
@ -168,6 +167,8 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
|||
enabled=True,
|
||||
),
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,188 @@
|
|||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from models.engine import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db_scalar(monkeypatch):
|
||||
"""Provide a controllable fake for db.session.scalar (SQLAlchemy 2.0 style)."""
|
||||
calls = []
|
||||
|
||||
def _install(side_effect):
|
||||
def _fake_scalar(statement):
|
||||
calls.append(statement)
|
||||
return side_effect(statement)
|
||||
|
||||
# Patch the modern API used by the model implementation
|
||||
monkeypatch.setattr(db.session, "scalar", _fake_scalar)
|
||||
|
||||
# Backward-compatibility: if the implementation still uses db.session.get,
|
||||
# make it delegate to the same side_effect so tests remain valid on older code.
|
||||
if hasattr(db.session, "get"):
|
||||
|
||||
def _fake_get(*_args, **_kwargs):
|
||||
return side_effect(None)
|
||||
|
||||
monkeypatch.setattr(db.session, "get", _fake_get)
|
||||
|
||||
return calls
|
||||
|
||||
return _install
|
||||
|
||||
|
||||
def make_account(id_: str = "acc-1"):
|
||||
# Use a simple object to avoid constructing a full SQLAlchemy model instance
|
||||
# Python 3.12 forbids reassigning __class__ for SimpleNamespace; not needed here.
|
||||
obj = types.SimpleNamespace()
|
||||
obj.id = id_
|
||||
return obj
|
||||
|
||||
|
||||
def make_end_user(id_: str = "user-1"):
|
||||
# Lightweight stand-in object; no need to spoof class identity.
|
||||
obj = types.SimpleNamespace()
|
||||
obj.id = id_
|
||||
return obj
|
||||
|
||||
|
||||
def test_created_by_account_returns_account_when_role_account(fake_db_scalar):
|
||||
account = make_account("acc-1")
|
||||
|
||||
# The implementation uses db.session.scalar(select(Account)...). We only need to
|
||||
# return the expected object when called; the exact SQL is irrelevant for this unit test.
|
||||
def side_effect(_statement):
|
||||
return account
|
||||
|
||||
fake_db_scalar(side_effect)
|
||||
|
||||
log = WorkflowNodeExecutionModel(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
workflow_id="w1",
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=None,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_execution_id=None,
|
||||
node_id="n1",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
inputs=None,
|
||||
process_data=None,
|
||||
outputs=None,
|
||||
status="succeeded",
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
execution_metadata=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
created_by="acc-1",
|
||||
)
|
||||
|
||||
assert log.created_by_account is account
|
||||
|
||||
|
||||
def test_created_by_account_returns_none_when_role_not_account(fake_db_scalar):
|
||||
# Even if an Account with matching id exists, property should return None when role is END_USER
|
||||
account = make_account("acc-1")
|
||||
|
||||
def side_effect(_statement):
|
||||
return account
|
||||
|
||||
fake_db_scalar(side_effect)
|
||||
|
||||
log = WorkflowNodeExecutionModel(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
workflow_id="w1",
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=None,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_execution_id=None,
|
||||
node_id="n1",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
inputs=None,
|
||||
process_data=None,
|
||||
outputs=None,
|
||||
status="succeeded",
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
execution_metadata=None,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
created_by="acc-1",
|
||||
)
|
||||
|
||||
assert log.created_by_account is None
|
||||
|
||||
|
||||
def test_created_by_end_user_returns_end_user_when_role_end_user(fake_db_scalar):
|
||||
end_user = make_end_user("user-1")
|
||||
|
||||
def side_effect(_statement):
|
||||
return end_user
|
||||
|
||||
fake_db_scalar(side_effect)
|
||||
|
||||
log = WorkflowNodeExecutionModel(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
workflow_id="w1",
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=None,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_execution_id=None,
|
||||
node_id="n1",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
inputs=None,
|
||||
process_data=None,
|
||||
outputs=None,
|
||||
status="succeeded",
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
execution_metadata=None,
|
||||
created_by_role=CreatorUserRole.END_USER.value,
|
||||
created_by="user-1",
|
||||
)
|
||||
|
||||
assert log.created_by_end_user is end_user
|
||||
|
||||
|
||||
def test_created_by_end_user_returns_none_when_role_not_end_user(fake_db_scalar):
|
||||
end_user = make_end_user("user-1")
|
||||
|
||||
def side_effect(_statement):
|
||||
return end_user
|
||||
|
||||
fake_db_scalar(side_effect)
|
||||
|
||||
log = WorkflowNodeExecutionModel(
|
||||
tenant_id="t1",
|
||||
app_id="a1",
|
||||
workflow_id="w1",
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=None,
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
node_execution_id=None,
|
||||
node_id="n1",
|
||||
node_type="start",
|
||||
title="Start",
|
||||
inputs=None,
|
||||
process_data=None,
|
||||
outputs=None,
|
||||
status="succeeded",
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
execution_metadata=None,
|
||||
created_by_role=CreatorUserRole.ACCOUNT.value,
|
||||
created_by="user-1",
|
||||
)
|
||||
|
||||
assert log.created_by_end_user is None
|
||||
|
|
@ -0,0 +1,819 @@
|
|||
"""
|
||||
Comprehensive unit tests for DatasetService creation methods.
|
||||
|
||||
This test suite covers:
|
||||
- create_empty_dataset for internal datasets
|
||||
- create_empty_dataset for external datasets
|
||||
- create_empty_rag_pipeline_dataset
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import (
|
||||
IconInfo,
|
||||
RagPipelineDatasetCreateEntity,
|
||||
)
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
|
||||
|
||||
class DatasetCreateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset creation tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock account."""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
account.current_tenant_id = tenant_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@staticmethod
|
||||
def create_retrieval_model_mock() -> Mock:
|
||||
"""Create a mock retrieval model."""
|
||||
retrieval_model = Mock(spec=RetrievalModel)
|
||||
retrieval_model.model_dump.return_value = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 2,
|
||||
"score_threshold": 0.0,
|
||||
}
|
||||
retrieval_model.reranking_model = None
|
||||
return retrieval_model
|
||||
|
||||
@staticmethod
|
||||
def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock:
|
||||
"""Create a mock external knowledge API."""
|
||||
api = Mock()
|
||||
api.id = api_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(api, key, value)
|
||||
return api
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
name: str = "Test Dataset",
|
||||
tenant_id: str = "tenant-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset."""
|
||||
dataset = create_autospec(Dataset, instance=True)
|
||||
dataset.id = dataset_id
|
||||
dataset.name = name
|
||||
dataset.tenant_id = tenant_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline_mock(
|
||||
pipeline_id: str = "pipeline-123",
|
||||
name: str = "Test Pipeline",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock pipeline."""
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = pipeline_id
|
||||
pipeline.name = name
|
||||
for key, value in kwargs.items():
|
||||
setattr(pipeline, key, value)
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestDatasetServiceCreateEmptyDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.create_empty_dataset method.
|
||||
|
||||
This test suite covers:
|
||||
- Internal dataset creation (vendor provider)
|
||||
- External dataset creation
|
||||
- High quality indexing technique with embedding models
|
||||
- Economy indexing technique
|
||||
- Retrieval model configuration
|
||||
- Error conditions (duplicate names, missing external knowledge IDs)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding,
|
||||
patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
|
||||
patch("services.dataset_service.ExternalDatasetService") as mock_external_service,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"model_manager": mock_model_manager,
|
||||
"check_embedding": mock_check_embedding,
|
||||
"check_reranking": mock_check_reranking,
|
||||
"external_service": mock_external_service,
|
||||
}
|
||||
|
||||
# ==================== Internal Dataset Creation Tests ====================
|
||||
|
||||
def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of basic internal dataset."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Test Dataset"
|
||||
description = "Test description"
|
||||
|
||||
# Mock database query to return None (no duplicate name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database session operations
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == name
|
||||
assert result.description == description
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.created_by == account.id
|
||||
assert result.updated_by == account.id
|
||||
assert result.provider == "vendor"
|
||||
assert result.permission == "only_me"
|
||||
mock_db.add.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of internal dataset with economy indexing."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Economy Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="economy",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "economy"
|
||||
assert result.embedding_model_provider is None
|
||||
assert result.embedding_model is None
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing_default_embedding(
|
||||
self, mock_dataset_service_dependencies
|
||||
):
|
||||
"""Test creation with high_quality indexing using default embedding model."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "High Quality Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
|
||||
self, mock_dataset_service_dependencies
|
||||
):
|
||||
"""Test creation with high_quality indexing using custom embedding model."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Custom Embedding Dataset"
|
||||
embedding_provider = "openai"
|
||||
embedding_model_name = "text-embedding-3-small"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock(
|
||||
model=embedding_model_name, provider=embedding_provider
|
||||
)
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
embedding_model_provider=embedding_provider,
|
||||
embedding_model_name=embedding_model_name,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_provider
|
||||
assert result.embedding_model == embedding_model_name
|
||||
mock_dataset_service_dependencies["check_embedding"].assert_called_once_with(
|
||||
tenant_id, embedding_provider, embedding_model_name
|
||||
)
|
||||
mock_model_manager_instance.get_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
provider=embedding_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=embedding_model_name,
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_retrieval_model(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with retrieval model configuration."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Retrieval Model Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock retrieval model
|
||||
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
|
||||
retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.retrieval_model == retrieval_model_dict
|
||||
retrieval_model.model_dump.assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with retrieval model that includes reranking."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Reranking Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock model manager
|
||||
embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock()
|
||||
mock_model_manager_instance = Mock()
|
||||
mock_model_manager_instance.get_default_model_instance.return_value = embedding_model
|
||||
mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance
|
||||
|
||||
# Mock retrieval model with reranking
|
||||
reranking_model = Mock()
|
||||
reranking_model.reranking_provider_name = "cohere"
|
||||
reranking_model.reranking_model_name = "rerank-english-v3.0"
|
||||
|
||||
retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock()
|
||||
retrieval_model.reranking_model = reranking_model
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_dataset_service_dependencies["check_reranking"].assert_called_once_with(
|
||||
tenant_id, "cohere", "rerank-english-v3.0"
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies):
|
||||
"""Test creation with custom permission setting."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Custom Permission Dataset"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
permission="all_team_members",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.permission == "all_team_members"
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# ==================== External Dataset Creation Tests ====================
|
||||
|
||||
def test_create_external_dataset_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful creation of external dataset."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "external-api-123"
|
||||
external_knowledge_id = "external-knowledge-456"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API
|
||||
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.provider == "external"
|
||||
assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with(
|
||||
external_api_id
|
||||
)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge API is not found."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "non-existent-api"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API not found
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="External API template not found"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id="knowledge-123",
|
||||
)
|
||||
|
||||
def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge ID is missing."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "External Dataset"
|
||||
external_api_id = "external-api-123"
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock external knowledge API
|
||||
external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id)
|
||||
mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api
|
||||
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external_knowledge_id is required"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_api_id,
|
||||
external_knowledge_id=None,
|
||||
)
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when dataset name already exists."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id)
|
||||
name = "Duplicate Dataset"
|
||||
|
||||
# Mock database query to return existing dataset
|
||||
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_dataset_service_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetServiceCreateEmptyRagPipelineDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method.
|
||||
|
||||
This test suite covers:
|
||||
- RAG pipeline dataset creation with provided name
|
||||
- RAG pipeline dataset creation with auto-generated name
|
||||
- Pipeline creation
|
||||
- Error conditions (duplicate names, missing current user)
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_pipeline_dependencies(self):
|
||||
"""Common mock setup for RAG pipeline dataset creation."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.current_user") as mock_current_user,
|
||||
patch("services.dataset_service.generate_incremental_name") as mock_generate_name,
|
||||
):
|
||||
# Configure mock_current_user to behave like a Flask-Login proxy
|
||||
# Default: no user (falsy)
|
||||
mock_current_user.id = None
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"current_user_mock": mock_current_user,
|
||||
"generate_name": mock_generate_name,
|
||||
}
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies):
|
||||
"""Test successful creation of RAG pipeline dataset with provided name."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "RAG Pipeline Dataset"
|
||||
description = "RAG Pipeline Description"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query (no duplicate name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description=description,
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == name
|
||||
assert result.description == description
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.created_by == user_id
|
||||
assert result.provider == "vendor"
|
||||
assert result.runtime_mode == "rag_pipeline"
|
||||
assert result.permission == "only_me"
|
||||
assert mock_db.add.call_count == 2 # Pipeline + Dataset
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation of RAG pipeline dataset with auto-generated name."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
auto_name = "Untitled 1"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query (empty name, need to generate)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock name generation
|
||||
mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity with empty name
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.name == auto_name
|
||||
mock_rag_pipeline_dependencies["generate_name"].assert_called_once()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies):
|
||||
"""Test error when RAG pipeline dataset name already exists."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Duplicate RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query to return existing dataset
|
||||
existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = existing_dataset
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies):
|
||||
"""Test error when current user is not available."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Mock current user as None - set id to None so the check fails
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = None
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="Test Dataset",
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Current user or current user id not found"):
|
||||
DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation with custom permission setting."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Custom Permission RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="all_team",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.permission == "all_team"
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies):
|
||||
"""Test creation with icon info configuration."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
name = "Icon Info RAG Dataset"
|
||||
|
||||
# Mock current user - set up the mock to have id attribute accessible directly
|
||||
mock_rag_pipeline_dependencies["current_user_mock"].id = user_id
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock database operations
|
||||
mock_db = mock_rag_pipeline_dependencies["db_session"]
|
||||
mock_db.add = Mock()
|
||||
mock_db.flush = Mock()
|
||||
mock_db.commit = Mock()
|
||||
|
||||
# Create entity with icon info
|
||||
icon_info = IconInfo(
|
||||
icon="📚",
|
||||
icon_background="#E8F5E9",
|
||||
icon_type="emoji",
|
||||
icon_url="https://example.com/icon.png",
|
||||
)
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=name,
|
||||
description="",
|
||||
icon_info=icon_info,
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.icon_info == icon_info.model_dump()
|
||||
mock_db.commit.assert_called_once()
|
||||
|
|
@ -0,0 +1,746 @@
|
|||
"""
|
||||
Comprehensive unit tests for DatasetService retrieval/list methods.
|
||||
|
||||
This test suite covers:
|
||||
- get_datasets - pagination, search, filtering, permissions
|
||||
- get_dataset - single dataset retrieval
|
||||
- get_datasets_by_ids - bulk retrieval
|
||||
- get_process_rules - dataset processing rules
|
||||
- get_dataset_queries - dataset query history
|
||||
- get_related_apps - apps using the dataset
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
Dataset,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
DatasetProcessRule,
|
||||
DatasetQuery,
|
||||
)
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
|
||||
class DatasetRetrievalTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset retrieval tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
name: str = "Test Dataset",
|
||||
tenant_id: str = "tenant-123",
|
||||
created_by: str = "user-123",
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.name = name
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.created_by = created_by
|
||||
dataset.permission = permission
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_account_mock(
|
||||
account_id: str = "account-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock account."""
|
||||
account = create_autospec(Account, instance=True)
|
||||
account.id = account_id
|
||||
account.current_tenant_id = tenant_id
|
||||
account.current_role = role
|
||||
for key, value in kwargs.items():
|
||||
setattr(account, key, value)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
account_id: str = "account-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset permission."""
|
||||
permission = Mock(spec=DatasetPermission)
|
||||
permission.dataset_id = dataset_id
|
||||
permission.account_id = account_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(permission, key, value)
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def create_process_rule_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
mode: str = "automatic",
|
||||
rules: dict | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset process rule."""
|
||||
process_rule = Mock(spec=DatasetProcessRule)
|
||||
process_rule.dataset_id = dataset_id
|
||||
process_rule.mode = mode
|
||||
process_rule.rules_dict = rules or {}
|
||||
for key, value in kwargs.items():
|
||||
setattr(process_rule, key, value)
|
||||
return process_rule
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_query_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
query_id: str = "query-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset query."""
|
||||
dataset_query = Mock(spec=DatasetQuery)
|
||||
dataset_query.id = query_id
|
||||
dataset_query.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset_query, key, value)
|
||||
return dataset_query
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join_mock(
|
||||
app_id: str = "app-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock app-dataset join."""
|
||||
join = Mock(spec=AppDatasetJoin)
|
||||
join.app_id = app_id
|
||||
join.dataset_id = dataset_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(join, key, value)
|
||||
return join
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasets:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.get_datasets method.
|
||||
|
||||
This test suite covers:
|
||||
- Pagination
|
||||
- Search functionality
|
||||
- Tag filtering
|
||||
- Permission-based filtering (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
|
||||
- Role-based filtering (OWNER, DATASET_OPERATOR, NORMAL)
|
||||
- include_all flag
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets tests."""
|
||||
with (
|
||||
patch("services.dataset_service.db.session") as mock_db,
|
||||
patch("services.dataset_service.db.paginate") as mock_paginate,
|
||||
patch("services.dataset_service.TagService") as mock_tag_service,
|
||||
):
|
||||
yield {
|
||||
"db_session": mock_db,
|
||||
"paginate": mock_paginate,
|
||||
"tag_service": mock_tag_service,
|
||||
}
|
||||
|
||||
# ==================== Basic Retrieval Tests ====================
|
||||
|
||||
def test_get_datasets_basic_pagination(self, mock_dependencies):
|
||||
"""Test basic pagination without user or filters."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=f"dataset-{i}", name=f"Dataset {i}", tenant_id=tenant_id
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
mock_paginate_result.total = 5
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 5
|
||||
assert total == 5
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_search(self, mock_dependencies):
|
||||
"""Test get_datasets with search keyword."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
search = "test"
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1", name="Test Dataset", tenant_id=tenant_id
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, search=search)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_with_tag_filtering(self, mock_dependencies):
|
||||
"""Test get_datasets with tag_ids filtering."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = ["tag-1", "tag-2"]
|
||||
|
||||
# Mock tag service
|
||||
target_ids = ["dataset-1", "dataset-2"]
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.return_value = target_ids
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in target_ids
|
||||
]
|
||||
mock_paginate_result.total = 2
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 2
|
||||
assert total == 2
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_called_once_with(
|
||||
"knowledge", tenant_id, tag_ids
|
||||
)
|
||||
|
||||
def test_get_datasets_with_empty_tag_ids(self, mock_dependencies):
|
||||
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = []
|
||||
|
||||
# Mock pagination result - when tag_ids is empty, tag filtering is skipped
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, tag_ids=tag_ids)
|
||||
|
||||
# Assert
|
||||
# When tag_ids is empty, tag filtering is skipped, so normal query results are returned
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
# Tag service should not be called when tag_ids is empty
|
||||
mock_dependencies["tag_service"].get_target_ids_by_tag_ids.assert_not_called()
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
# ==================== Permission-Based Filtering Tests ====================
|
||||
|
||||
def test_get_datasets_without_user_shows_only_all_team(self, mock_dependencies):
|
||||
"""Test that without user, only ALL_TEAM datasets are shown."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page, per_page, tenant_id=tenant_id, user=None)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_owner_with_include_all(self, mock_dependencies):
|
||||
"""Test that OWNER with include_all=True sees all datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="owner-123", tenant_id=tenant_id, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (empty - owner doesn't need explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=f"dataset-{i}", tenant_id=tenant_id)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page=1, per_page=20, tenant_id=tenant_id, user=user, include_all=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
|
||||
def test_get_datasets_normal_user_only_me_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ONLY_ME datasets they created."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
created_by=user_id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_all_team_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees ALL_TEAM datasets."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id="user-123", tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query (no explicit permissions)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id="dataset-1",
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_partial_team_with_permission(self, mock_dependencies):
|
||||
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "user-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - user has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_with_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
dataset_id = "dataset-1"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - operator has permission
|
||||
permission = DatasetRetrievalTestDataFactory.create_dataset_permission_mock(
|
||||
dataset_id=dataset_id, account_id=user_id
|
||||
)
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = [permission]
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
]
|
||||
mock_paginate_result.total = 1
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_without_permissions(self, mock_dependencies):
|
||||
"""Test that DATASET_OPERATOR without permissions returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
user_id = "operator-123"
|
||||
user = DatasetRetrievalTestDataFactory.create_account_mock(
|
||||
account_id=user_id, tenant_id=tenant_id, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
|
||||
# Mock dataset permissions query - no permissions
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant_id, user=user)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetDataset:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_dataset_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of a single dataset."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = dataset
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == dataset_id
|
||||
mock_query.filter_by.assert_called_once_with(id=dataset_id)
|
||||
|
||||
def test_get_dataset_not_found(self, mock_dependencies):
|
||||
"""Test retrieval when dataset doesn't exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.filter_by.return_value.first.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetsByIds:
|
||||
"""Comprehensive unit tests for DatasetService.get_datasets_by_ids method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_datasets_by_ids tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_datasets_by_ids_success(self, mock_dependencies):
|
||||
"""Test successful bulk retrieval of datasets by IDs."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = [str(uuid4()), str(uuid4()), str(uuid4())]
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_mock(dataset_id=dataset_id, tenant_id=tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
mock_paginate_result.total = len(dataset_ids)
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
assert all(dataset.id in dataset_ids for dataset in datasets)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_datasets_by_ids_empty_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with empty list returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
dataset_ids = []
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(dataset_ids, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
def test_get_datasets_by_ids_none_list(self, mock_dependencies):
|
||||
"""Test get_datasets_by_ids with None returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets_by_ids(None, tenant_id)
|
||||
|
||||
# Assert
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
mock_dependencies["paginate"].assert_not_called()
|
||||
|
||||
|
||||
class TestDatasetServiceGetProcessRules:
|
||||
"""Comprehensive unit tests for DatasetService.get_process_rules method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_process_rules tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_process_rules_with_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when rule exists."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
rules_data = {
|
||||
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500},
|
||||
}
|
||||
process_rule = DatasetRetrievalTestDataFactory.create_process_rule_mock(
|
||||
dataset_id=dataset_id, mode="custom", rules=rules_data
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = process_rule
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == "custom"
|
||||
assert result["rules"] == rules_data
|
||||
|
||||
def test_get_process_rules_without_existing_rule(self, mock_dependencies):
|
||||
"""Test retrieval of process rules when no rule exists (returns defaults)."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value = None
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result["mode"] == DocumentService.DEFAULT_RULES["mode"]
|
||||
assert "rules" in result
|
||||
assert result["rules"] == DocumentService.DEFAULT_RULES["rules"]
|
||||
|
||||
|
||||
class TestDatasetServiceGetDatasetQueries:
|
||||
"""Comprehensive unit tests for DatasetService.get_dataset_queries method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_dataset_queries tests."""
|
||||
with patch("services.dataset_service.db.paginate") as mock_paginate:
|
||||
yield {"paginate": mock_paginate}
|
||||
|
||||
def test_get_dataset_queries_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of dataset queries."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset_query_mock(dataset_id=dataset_id, query_id=f"query-{i}")
|
||||
for i in range(3)
|
||||
]
|
||||
mock_paginate_result.total = 3
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert len(queries) == 3
|
||||
assert total == 3
|
||||
assert all(query.dataset_id == dataset_id for query in queries)
|
||||
mock_dependencies["paginate"].assert_called_once()
|
||||
|
||||
def test_get_dataset_queries_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no queries exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
# Mock pagination result (empty)
|
||||
mock_paginate_result = Mock()
|
||||
mock_paginate_result.items = []
|
||||
mock_paginate_result.total = 0
|
||||
mock_dependencies["paginate"].return_value = mock_paginate_result
|
||||
|
||||
# Act
|
||||
queries, total = DatasetService.get_dataset_queries(dataset_id, page, per_page)
|
||||
|
||||
# Assert
|
||||
assert queries == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestDatasetServiceGetRelatedApps:
|
||||
"""Comprehensive unit tests for DatasetService.get_related_apps method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Common mock setup for get_related_apps tests."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
yield {"db_session": mock_db}
|
||||
|
||||
def test_get_related_apps_success(self, mock_dependencies):
|
||||
"""Test successful retrieval of related apps."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock app-dataset joins
|
||||
app_joins = [
|
||||
DatasetRetrievalTestDataFactory.create_app_dataset_join_mock(app_id=f"app-{i}", dataset_id=dataset_id)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
# Mock database query
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = app_joins
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(join.dataset_id == dataset_id for join in result)
|
||||
mock_query.where.assert_called_once()
|
||||
mock_query.where.return_value.order_by.assert_called_once()
|
||||
|
||||
def test_get_related_apps_empty_result(self, mock_dependencies):
|
||||
"""Test retrieval when no related apps exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Mock database query returning empty list
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.order_by.return_value.all.return_value = []
|
||||
mock_dependencies["db_session"].query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset_id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
106
api/uv.lock
106
api/uv.lock
|
|
@ -1003,7 +1003,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "clickhouse-connect"
|
||||
version = "0.7.19"
|
||||
version = "0.10.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
|
|
@ -1012,28 +1012,24 @@ dependencies = [
|
|||
{ name = "urllib3" },
|
||||
{ name = "zstandard" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/8e/bf6012f7b45dbb74e19ad5c881a7bbcd1e7dd2b990f12cc434294d917800/clickhouse-connect-0.7.19.tar.gz", hash = "sha256:ce8f21f035781c5ef6ff57dc162e8150779c009b59f14030ba61f8c9c10c06d0", size = 84918, upload-time = "2024-08-21T21:37:16.639Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7b/fd/f8bea1157d40f117248dcaa9abdbf68c729513fcf2098ab5cb4aa58768b8/clickhouse_connect-0.10.0.tar.gz", hash = "sha256:a0256328802c6e5580513e197cef7f9ba49a99fc98e9ba410922873427569564", size = 104753, upload-time = "2025-11-14T20:31:00.947Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/6f/a78cad40dc0f1fee19094c40abd7d23ff04bb491732c3a65b3661d426c89/clickhouse_connect-0.7.19-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee47af8926a7ec3a970e0ebf29a82cbbe3b1b7eae43336a81b3a0ca18091de5f", size = 253530, upload-time = "2024-08-21T21:35:53.372Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/82/419d110149900ace5eb0787c668d11e1657ac0eabb65c1404f039746f4ed/clickhouse_connect-0.7.19-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce429233b2d21a8a149c8cd836a2555393cbcf23d61233520db332942ffb8964", size = 245691, upload-time = "2024-08-21T21:35:55.074Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/9c/ad6708ced6cf9418334d2bf19bbba3c223511ed852eb85f79b1e7c20cdbd/clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617c04f5c46eed3344a7861cd96fb05293e70d3b40d21541b1e459e7574efa96", size = 1055273, upload-time = "2024-08-21T21:35:56.478Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/99/88c24542d6218100793cfb13af54d7ad4143d6515b0b3d621ba3b5a2d8af/clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08e33b8cc2dc1873edc5ee4088d4fc3c0dbb69b00e057547bcdc7e9680b43e5", size = 1067030, upload-time = "2024-08-21T21:35:58.096Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/84/19eb776b4e760317c21214c811f04f612cba7eee0f2818a7d6806898a994/clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:921886b887f762e5cc3eef57ef784d419a3f66df85fd86fa2e7fbbf464c4c54a", size = 1027207, upload-time = "2024-08-21T21:35:59.832Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/81/c2982a33b088b6c9af5d0bdc46413adc5fedceae063b1f8b56570bb28887/clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ad0cf8552a9e985cfa6524b674ae7c8f5ba51df5bd3ecddbd86c82cdbef41a7", size = 1054850, upload-time = "2024-08-21T21:36:01.559Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/a4/4a84ed3e92323d12700011cc8c4039f00a8c888079d65e75a4d4758ba288/clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:70f838ef0861cdf0e2e198171a1f3fd2ee05cf58e93495eeb9b17dfafb278186", size = 1022784, upload-time = "2024-08-21T21:36:02.805Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/67/3f5cc6f78c9adbbd6a3183a3f9f3196a116be19e958d7eaa6e307b391fed/clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c5f0d207cb0dcc1adb28ced63f872d080924b7562b263a9d54d4693b670eb066", size = 1071084, upload-time = "2024-08-21T21:36:04.052Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/8d/a294e1cc752e22bc6ee08aa421ea31ed9559b09d46d35499449140a5c374/clickhouse_connect-0.7.19-cp311-cp311-win32.whl", hash = "sha256:8c96c4c242b98fcf8005e678a26dbd4361748721b6fa158c1fe84ad15c7edbbe", size = 221156, upload-time = "2024-08-21T21:36:05.72Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/69/09b3a4e53f5d3d770e9fa70f6f04642cdb37cc76d37279c55fd4e868f845/clickhouse_connect-0.7.19-cp311-cp311-win_amd64.whl", hash = "sha256:bda092bab224875ed7c7683707d63f8a2322df654c4716e6611893a18d83e908", size = 238826, upload-time = "2024-08-21T21:36:06.892Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/f8/1d48719728bac33c1a9815e0a7230940e078fd985b09af2371715de78a3c/clickhouse_connect-0.7.19-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8f170d08166438d29f0dcfc8a91b672c783dc751945559e65eefff55096f9274", size = 256687, upload-time = "2024-08-21T21:36:08.245Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/0d/3cbbbd204be045c4727f9007679ad97d3d1d559b43ba844373a79af54d16/clickhouse_connect-0.7.19-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26b80cb8f66bde9149a9a2180e2cc4895c1b7d34f9dceba81630a9b9a9ae66b2", size = 247631, upload-time = "2024-08-21T21:36:09.679Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/44/adb55285226d60e9c46331a9980c88dad8c8de12abb895c4e3149a088092/clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba80e3598acf916c4d1b2515671f65d9efee612a783c17c56a5a646f4db59b9", size = 1053767, upload-time = "2024-08-21T21:36:11.361Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/f3/a109c26a41153768be57374cb823cac5daf74c9098a5c61081ffabeb4e59/clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d38c30bd847af0ce7ff738152478f913854db356af4d5824096394d0eab873d", size = 1072014, upload-time = "2024-08-21T21:36:12.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/80/9c200e5e392a538f2444c9a6a93e1cf0e36588c7e8720882ac001e23b246/clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d41d4b159071c0e4f607563932d4fa5c2a8fc27d3ba1200d0929b361e5191864", size = 1027423, upload-time = "2024-08-21T21:36:14.483Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/a3/219fcd1572f1ce198dcef86da8c6c526b04f56e8b7a82e21119677f89379/clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3682c2426f5dbda574611210e3c7c951b9557293a49eb60a7438552435873889", size = 1053683, upload-time = "2024-08-21T21:36:15.828Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/df/687d90fbc0fd8ce586c46400f3791deac120e4c080aa8b343c0f676dfb08/clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6d492064dca278eb61be3a2d70a5f082e2ebc8ceebd4f33752ae234116192020", size = 1021120, upload-time = "2024-08-21T21:36:17.184Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/3b/39ba71b103275df8ec90d424dbaca2dba82b28398c3d2aeac5a0141b6aae/clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:62612da163b934c1ff35df6155a47cf17ac0e2d2f9f0f8f913641e5c02cdf39f", size = 1073652, upload-time = "2024-08-21T21:36:19.053Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/92/06df8790a7d93d5d5f1098604fc7d79682784818030091966a3ce3f766a8/clickhouse_connect-0.7.19-cp312-cp312-win32.whl", hash = "sha256:196e48c977affc045794ec7281b4d711e169def00535ecab5f9fdeb8c177f149", size = 221589, upload-time = "2024-08-21T21:36:20.796Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/4e/f90caf963d14865c7a3f0e5d80b77e67e0fe0bf39b3de84110707746fa6b/clickhouse_connect-0.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:195f1824405501b747b572e1365c6265bb1629eeb712ce91eda91da3c5794879", size = 272911, upload-time = "2025-11-14T20:29:57.129Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/c7/e01bd2dd80ea4fbda8968e5022c60091a872fd9de0a123239e23851da231/clickhouse_connect-0.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7907624635fe7f28e1b85c7c8b125a72679a63ecdb0b9f4250b704106ef438f8", size = 265938, upload-time = "2025-11-14T20:29:58.443Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/07/8b567b949abca296e118331d13380bbdefa4225d7d1d32233c59d4b4b2e1/clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60772faa54d56f0fa34650460910752a583f5948f44dddeabfafaecbca21fc54", size = 1113548, upload-time = "2025-11-14T20:29:59.781Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/13/11f2d37fc95e74d7e2d80702cde87666ce372486858599a61f5209e35fc5/clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7fe2a6cd98517330c66afe703fb242c0d3aa2c91f2f7dc9fb97c122c5c60c34b", size = 1135061, upload-time = "2025-11-14T20:30:01.244Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/d0/517181ea80060f84d84cff4d42d330c80c77bb352b728fb1f9681fbad291/clickhouse_connect-0.10.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a2427d312bc3526520a0be8c648479af3f6353da7a33a62db2368d6203b08efd", size = 1105105, upload-time = "2025-11-14T20:30:02.679Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/b2/4ad93e898562725b58c537cad83ab2694c9b1c1ef37fa6c3f674bdad366a/clickhouse_connect-0.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:63bbb5721bfece698e155c01b8fa95ce4377c584f4d04b43f383824e8a8fa129", size = 1150791, upload-time = "2025-11-14T20:30:03.824Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/a4/fdfbfacc1fa67b8b1ce980adcf42f9e3202325586822840f04f068aff395/clickhouse_connect-0.10.0-cp311-cp311-win32.whl", hash = "sha256:48554e836c6b56fe0854d9a9f565569010583d4960094d60b68a53f9f83042f0", size = 244014, upload-time = "2025-11-14T20:30:05.157Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/50/cf53f33f4546a9ce2ab1b9930db4850aa1ae53bff1e4e4fa97c566cdfa19/clickhouse_connect-0.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:9eb8df083e5fda78ac7249938691c2c369e8578b5df34c709467147e8289f1d9", size = 262356, upload-time = "2025-11-14T20:30:06.478Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/59/fadbbf64f4c6496cd003a0a3c9223772409a86d0eea9d4ff45d2aa88aabf/clickhouse_connect-0.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b090c7d8e602dd084b2795265cd30610461752284763d9ad93a5d619a0e0ff21", size = 276401, upload-time = "2025-11-14T20:30:07.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/e3/781f9970f2ef202410f0d64681e42b2aecd0010097481a91e4df186a36c7/clickhouse_connect-0.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b8a708d38b81dcc8c13bb85549c904817e304d2b7f461246fed2945524b7a31b", size = 268193, upload-time = "2025-11-14T20:30:08.503Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/e0/64ab66b38fce762b77b5203a4fcecc603595f2a2361ce1605fc7bb79c835/clickhouse_connect-0.10.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3646fc9184a5469b95cf4a0846e6954e6e9e85666f030a5d2acae58fa8afb37e", size = 1123810, upload-time = "2025-11-14T20:30:09.62Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f5/03/19121aecf11a30feaf19049be96988131798c54ac6ba646a38e5faecaa0a/clickhouse_connect-0.10.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fe7e6be0f40a8a77a90482944f5cc2aa39084c1570899e8d2d1191f62460365b", size = 1153409, upload-time = "2025-11-14T20:30:10.855Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/ee/63870fd8b666c6030393950ad4ee76b7b69430f5a49a5d3fa32a70b11942/clickhouse_connect-0.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:88b4890f13163e163bf6fa61f3a013bb974c95676853b7a4e63061faf33911ac", size = 1104696, upload-time = "2025-11-14T20:30:12.187Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/bc/fcd8da1c4d007ebce088783979c495e3d7360867cfa8c91327ed235778f5/clickhouse_connect-0.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6286832cc79affc6fddfbf5563075effa65f80e7cd1481cf2b771ce317c67d08", size = 1156389, upload-time = "2025-11-14T20:30:13.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/33/7cb99cc3fc503c23fd3a365ec862eb79cd81c8dc3037242782d709280fa9/clickhouse_connect-0.10.0-cp312-cp312-win32.whl", hash = "sha256:92b8b6691a92d2613ee35f5759317bd4be7ba66d39bf81c4deed620feb388ca6", size = 243682, upload-time = "2025-11-14T20:30:14.52Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/5c/12eee6a1f5ecda2dfc421781fde653c6d6ca6f3080f24547c0af40485a5a/clickhouse_connect-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:1159ee2c33e7eca40b53dda917a8b6a2ed889cb4c54f3d83b303b31ddb4f351d", size = 262790, upload-time = "2025-11-14T20:30:15.555Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1055,6 +1051,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/23/38/749c708619f402d4d582dfa73fbeb64ade77b1f250a93bd064d2a1aa3776/clickzetta_connector_python-0.8.106-py3-none-any.whl", hash = "sha256:120d6700051d97609dbd6655c002ab3bc260b7c8e67d39dfc7191e749563f7b4", size = 78121, upload-time = "2025-10-29T02:38:15.014Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloudpickle"
|
||||
version = "3.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloudscraper"
|
||||
version = "1.2.71"
|
||||
|
|
@ -1255,6 +1260,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/0d/c3/e90f4a4feae6410f914f8ebac129b9ae7a8c92eb60a638012dde42030a9d/cryptography-46.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6b5063083824e5509fdba180721d55909ffacccc8adbec85268b48439423d78c", size = 3438528, upload-time = "2025-10-15T23:18:26.227Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "databricks-sdk"
|
||||
version = "0.73.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "google-auth" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a8/7f/cfb2a00d10f6295332616e5b22f2ae3aaf2841a3afa6c49262acb6b94f5b/databricks_sdk-0.73.0.tar.gz", hash = "sha256:db09eaaacd98e07dded78d3e7ab47d2f6c886e0380cb577977bd442bace8bd8d", size = 801017, upload-time = "2025-11-05T06:52:58.509Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a7/27/b822b474aaefb684d11df358d52e012699a2a8af231f9b47c54b73f280cb/databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff", size = 753896, upload-time = "2025-11-05T06:52:56.451Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dataclasses-json"
|
||||
version = "0.6.7"
|
||||
|
|
@ -1350,6 +1369,7 @@ dependencies = [
|
|||
{ name = "langsmith" },
|
||||
{ name = "litellm" },
|
||||
{ name = "markdown" },
|
||||
{ name = "mlflow-skinny" },
|
||||
{ name = "numpy" },
|
||||
{ name = "openpyxl" },
|
||||
{ name = "opentelemetry-api" },
|
||||
|
|
@ -1544,6 +1564,7 @@ requires-dist = [
|
|||
{ name = "langsmith", specifier = "~=0.1.77" },
|
||||
{ name = "litellm", specifier = "==1.77.1" },
|
||||
{ name = "markdown", specifier = "~=3.5.1" },
|
||||
{ name = "mlflow-skinny", specifier = ">=3.0.0" },
|
||||
{ name = "numpy", specifier = "~=1.26.4" },
|
||||
{ name = "openpyxl", specifier = "~=3.1.5" },
|
||||
{ name = "opentelemetry-api", specifier = "==1.27.0" },
|
||||
|
|
@ -1678,7 +1699,7 @@ vdb = [
|
|||
{ name = "alibabacloud-gpdb20160503", specifier = "~=3.8.0" },
|
||||
{ name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" },
|
||||
{ name = "chromadb", specifier = "==0.5.20" },
|
||||
{ name = "clickhouse-connect", specifier = "~=0.7.16" },
|
||||
{ name = "clickhouse-connect", specifier = "~=0.10.0" },
|
||||
{ name = "clickzetta-connector-python", specifier = ">=0.8.102" },
|
||||
{ name = "couchbase", specifier = "~=4.3.0" },
|
||||
{ name = "elasticsearch", specifier = "==8.14.0" },
|
||||
|
|
@ -3338,6 +3359,36 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/d3/82/41d9b80f09b82e066894d9b508af07b7b0fa325ce0322980674de49106a0/milvus_lite-2.5.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25ce13f4b8d46876dd2b7ac8563d7d8306da7ff3999bb0d14b116b30f71d706c", size = 55263911, upload-time = "2025-06-30T04:24:19.434Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlflow-skinny"
|
||||
version = "3.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cachetools" },
|
||||
{ name = "click" },
|
||||
{ name = "cloudpickle" },
|
||||
{ name = "databricks-sdk" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "gitpython" },
|
||||
{ name = "importlib-metadata" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-proto" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "packaging" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "sqlparse" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/8e/2a2d0cd5b1b985c5278202805f48aae6f2adc3ddc0fce3385ec50e07e258/mlflow_skinny-3.6.0.tar.gz", hash = "sha256:cc04706b5b6faace9faf95302a6e04119485e1bfe98ddc9b85b81984e80944b6", size = 1963286, upload-time = "2025-11-07T18:33:52.596Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/78/e8fdc3e1708bdfd1eba64f41ce96b461cae1b505aa08b69352ac99b4caa4/mlflow_skinny-3.6.0-py3-none-any.whl", hash = "sha256:c83b34fce592acb2cc6bddcb507587a6d9ef3f590d9e7a8658c85e0980596d78", size = 2364629, upload-time = "2025-11-07T18:33:50.744Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mmh3"
|
||||
version = "5.2.0"
|
||||
|
|
@ -5729,6 +5780,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/9b/70/20c1912bc0bfebf516d59d618209443b136c58a7cff141afa7cf30969988/sqlglot-27.29.0-py3-none-any.whl", hash = "sha256:9a5ea8ac61826a7763de10cad45a35f0aa9bfcf7b96ee74afb2314de9089e1cb", size = 526060, upload-time = "2025-10-29T13:50:22.061Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparse"
|
||||
version = "0.5.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e5/40/edede8dd6977b0d3da179a342c198ed100dd2aba4be081861ee5911e4da4/sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272", size = 84999, upload-time = "2024-12-10T12:05:30.728Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sseclient-py"
|
||||
version = "1.8.0"
|
||||
|
|
|
|||
|
|
@ -40,7 +40,9 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
|
|||
- Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
|
||||
1. **Running Middleware Services**:
|
||||
- Navigate to the `docker` directory.
|
||||
- Execute `docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d` to start the middleware services. (Change the profile to other vector database if you are not using weaviate)
|
||||
- Execute `docker compose --env-file middleware.env -f docker-compose.middleware.yaml -p dify up -d` to start PostgreSQL/MySQL (per `DB_TYPE`) plus the bundled Weaviate instance.
|
||||
|
||||
> Compose automatically loads `COMPOSE_PROFILES=${DB_TYPE:-postgresql},weaviate` from `middleware.env`, so no extra `--profile` flags are needed. Adjust variables in `middleware.env` if you want a different combination of services.
|
||||
|
||||
### Migration for Existing Users
|
||||
|
||||
|
|
|
|||
|
|
@ -134,6 +134,13 @@ WEAVIATE_AUTHORIZATION_ADMINLIST_ENABLED=true
|
|||
WEAVIATE_AUTHORIZATION_ADMINLIST_USERS=hello@dify.ai
|
||||
WEAVIATE_HOST_VOLUME=./volumes/weaviate
|
||||
|
||||
# ------------------------------
|
||||
# Docker Compose profile configuration
|
||||
# ------------------------------
|
||||
# Loaded automatically when running `docker compose --env-file middleware.env ...`.
|
||||
# Controls which DB/vector services start, so no extra `--profile` flag is needed.
|
||||
COMPOSE_PROFILES=${DB_TYPE:-postgresql},weaviate
|
||||
|
||||
# ------------------------------
|
||||
# Docker Compose Service Expose Host Port Configurations
|
||||
# ------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
module.exports = {
|
||||
presets: [
|
||||
[
|
||||
"@babel/preset-env",
|
||||
{
|
||||
targets: {
|
||||
node: "current",
|
||||
},
|
||||
},
|
||||
],
|
||||
],
|
||||
};
|
||||
|
|
@ -71,7 +71,7 @@ export const routes = {
|
|||
},
|
||||
stopWorkflow: {
|
||||
method: "POST",
|
||||
url: (task_id) => `/workflows/${task_id}/stop`,
|
||||
url: (task_id) => `/workflows/tasks/${task_id}/stop`,
|
||||
}
|
||||
|
||||
};
|
||||
|
|
@ -94,11 +94,13 @@ export class DifyClient {
|
|||
stream = false,
|
||||
headerParams = {}
|
||||
) {
|
||||
const isFormData =
|
||||
(typeof FormData !== "undefined" && data instanceof FormData) ||
|
||||
(data && data.constructor && data.constructor.name === "FormData");
|
||||
const headers = {
|
||||
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
"Content-Type": "application/json",
|
||||
...headerParams
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
...(isFormData ? {} : { "Content-Type": "application/json" }),
|
||||
...headerParams,
|
||||
};
|
||||
|
||||
const url = `${this.baseUrl}${endpoint}`;
|
||||
|
|
@ -152,12 +154,7 @@ export class DifyClient {
|
|||
return this.sendRequest(
|
||||
routes.fileUpload.method,
|
||||
routes.fileUpload.url(),
|
||||
data,
|
||||
null,
|
||||
false,
|
||||
{
|
||||
"Content-Type": 'multipart/form-data'
|
||||
}
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -179,8 +176,8 @@ export class DifyClient {
|
|||
getMeta(user) {
|
||||
const params = { user };
|
||||
return this.sendRequest(
|
||||
routes.meta.method,
|
||||
routes.meta.url(),
|
||||
routes.getMeta.method,
|
||||
routes.getMeta.url(),
|
||||
null,
|
||||
params
|
||||
);
|
||||
|
|
@ -320,12 +317,7 @@ export class ChatClient extends DifyClient {
|
|||
return this.sendRequest(
|
||||
routes.audioToText.method,
|
||||
routes.audioToText.url(),
|
||||
data,
|
||||
null,
|
||||
false,
|
||||
{
|
||||
"Content-Type": 'multipart/form-data'
|
||||
}
|
||||
data
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
import { DifyClient, BASE_URL, routes } from ".";
|
||||
import { DifyClient, WorkflowClient, BASE_URL, routes } from ".";
|
||||
|
||||
import axios from 'axios'
|
||||
|
||||
jest.mock('axios')
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('Client', () => {
|
||||
let difyClient
|
||||
beforeEach(() => {
|
||||
|
|
@ -27,13 +31,9 @@ describe('Send Requests', () => {
|
|||
difyClient = new DifyClient('test')
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks()
|
||||
})
|
||||
|
||||
it('should make a successful request to the application parameter', async () => {
|
||||
const method = 'GET'
|
||||
const endpoint = routes.application.url
|
||||
const endpoint = routes.application.url()
|
||||
const expectedResponse = { data: 'response' }
|
||||
axios.mockResolvedValue(expectedResponse)
|
||||
|
||||
|
|
@ -62,4 +62,80 @@ describe('Send Requests', () => {
|
|||
errorMessage
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the getMeta route configuration', async () => {
|
||||
axios.mockResolvedValue({ data: 'ok' })
|
||||
await difyClient.getMeta('end-user')
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.getMeta.method,
|
||||
url: `${BASE_URL}${routes.getMeta.url()}`,
|
||||
params: { user: 'end-user' },
|
||||
headers: {
|
||||
Authorization: `Bearer ${difyClient.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('File uploads', () => {
|
||||
let difyClient
|
||||
const OriginalFormData = global.FormData
|
||||
|
||||
beforeAll(() => {
|
||||
global.FormData = class FormDataMock {}
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
global.FormData = OriginalFormData
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
difyClient = new DifyClient('test')
|
||||
})
|
||||
|
||||
it('does not override multipart boundary headers for FormData', async () => {
|
||||
const form = new FormData()
|
||||
axios.mockResolvedValue({ data: 'ok' })
|
||||
|
||||
await difyClient.fileUpload(form)
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.fileUpload.method,
|
||||
url: `${BASE_URL}${routes.fileUpload.url()}`,
|
||||
data: form,
|
||||
params: null,
|
||||
headers: {
|
||||
Authorization: `Bearer ${difyClient.apiKey}`,
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Workflow client', () => {
|
||||
let workflowClient
|
||||
|
||||
beforeEach(() => {
|
||||
workflowClient = new WorkflowClient('test')
|
||||
})
|
||||
|
||||
it('uses tasks stop path for workflow stop', async () => {
|
||||
axios.mockResolvedValue({ data: 'stopped' })
|
||||
await workflowClient.stop('task-1', 'end-user')
|
||||
|
||||
expect(axios).toHaveBeenCalledWith({
|
||||
method: routes.stopWorkflow.method,
|
||||
url: `${BASE_URL}${routes.stopWorkflow.url('task-1')}`,
|
||||
data: { user: 'end-user' },
|
||||
params: null,
|
||||
headers: {
|
||||
Authorization: `Bearer ${workflowClient.apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
responseType: 'json',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
module.exports = {
|
||||
testEnvironment: "node",
|
||||
transform: {
|
||||
"^.+\\.[tj]sx?$": "babel-jest",
|
||||
},
|
||||
};
|
||||
|
|
@ -18,11 +18,6 @@
|
|||
"scripts": {
|
||||
"test": "jest"
|
||||
},
|
||||
"jest": {
|
||||
"transform": {
|
||||
"^.+\\.[t|j]sx?$": "babel-jest"
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.3.5"
|
||||
},
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
|
|||
import { useBoolean } from 'ahooks'
|
||||
import TracingIcon from './tracing-icon'
|
||||
import ProviderPanel from './provider-panel'
|
||||
import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import { TracingProvider } from './type'
|
||||
import ProviderConfigModal from './provider-config-modal'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
|
|
@ -30,8 +30,10 @@ export type PopupProps = {
|
|||
opikConfig: OpikConfig | null
|
||||
weaveConfig: WeaveConfig | null
|
||||
aliyunConfig: AliyunConfig | null
|
||||
mlflowConfig: MLflowConfig | null
|
||||
databricksConfig: DatabricksConfig | null
|
||||
tencentConfig: TencentConfig | null
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => void
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | MLflowConfig | DatabricksConfig) => void
|
||||
onConfigRemoved: (provider: TracingProvider) => void
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +51,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
opikConfig,
|
||||
weaveConfig,
|
||||
aliyunConfig,
|
||||
mlflowConfig,
|
||||
databricksConfig,
|
||||
tencentConfig,
|
||||
onConfigUpdated,
|
||||
onConfigRemoved,
|
||||
|
|
@ -73,7 +77,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
}
|
||||
}, [onChooseProvider])
|
||||
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig) => {
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => {
|
||||
onConfigUpdated(currentProvider!, payload)
|
||||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigUpdated])
|
||||
|
|
@ -83,8 +87,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigRemoved])
|
||||
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !tencentConfig
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && mlflowConfig && databricksConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !mlflowConfig && !databricksConfig && !tencentConfig
|
||||
|
||||
const switchContent = (
|
||||
<Switch
|
||||
|
|
@ -185,6 +189,32 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
/>
|
||||
)
|
||||
|
||||
const mlflowPanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.mlflow}
|
||||
readOnly={readOnly}
|
||||
config={mlflowConfig}
|
||||
hasConfigured={!!mlflowConfig}
|
||||
onConfig={handleOnConfig(TracingProvider.mlflow)}
|
||||
isChosen={chosenProvider === TracingProvider.mlflow}
|
||||
onChoose={handleOnChoose(TracingProvider.mlflow)}
|
||||
key="mlflow-provider-panel"
|
||||
/>
|
||||
)
|
||||
|
||||
const databricksPanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.databricks}
|
||||
readOnly={readOnly}
|
||||
config={databricksConfig}
|
||||
hasConfigured={!!databricksConfig}
|
||||
onConfig={handleOnConfig(TracingProvider.databricks)}
|
||||
isChosen={chosenProvider === TracingProvider.databricks}
|
||||
onChoose={handleOnChoose(TracingProvider.databricks)}
|
||||
key="databricks-provider-panel"
|
||||
/>
|
||||
)
|
||||
|
||||
const tencentPanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.tencent}
|
||||
|
|
@ -221,6 +251,12 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
if (aliyunConfig)
|
||||
configuredPanels.push(aliyunPanel)
|
||||
|
||||
if (mlflowConfig)
|
||||
configuredPanels.push(mlflowPanel)
|
||||
|
||||
if (databricksConfig)
|
||||
configuredPanels.push(databricksPanel)
|
||||
|
||||
if (tencentConfig)
|
||||
configuredPanels.push(tencentPanel)
|
||||
|
||||
|
|
@ -251,6 +287,12 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
if (!aliyunConfig)
|
||||
notConfiguredPanels.push(aliyunPanel)
|
||||
|
||||
if (!mlflowConfig)
|
||||
notConfiguredPanels.push(mlflowPanel)
|
||||
|
||||
if (!databricksConfig)
|
||||
notConfiguredPanels.push(databricksPanel)
|
||||
|
||||
if (!tencentConfig)
|
||||
notConfiguredPanels.push(tencentPanel)
|
||||
|
||||
|
|
@ -258,6 +300,10 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
}
|
||||
|
||||
const configuredProviderConfig = () => {
|
||||
if (currentProvider === TracingProvider.mlflow)
|
||||
return mlflowConfig
|
||||
if (currentProvider === TracingProvider.databricks)
|
||||
return databricksConfig
|
||||
if (currentProvider === TracingProvider.arize)
|
||||
return arizeConfig
|
||||
if (currentProvider === TracingProvider.phoenix)
|
||||
|
|
@ -316,6 +362,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
|||
{langfusePanel}
|
||||
{langSmithPanel}
|
||||
{opikPanel}
|
||||
{mlflowPanel}
|
||||
{databricksPanel}
|
||||
{weavePanel}
|
||||
{arizePanel}
|
||||
{phoenixPanel}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue