diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py index e0896a8dc2..a5a3e4ebbd 100644 --- a/api/controllers/common/schema.py +++ b/api/controllers/common/schema.py @@ -1,7 +1,11 @@ """Helpers for registering Pydantic models with Flask-RESTX namespaces.""" +from enum import StrEnum + from flask_restx import Namespace -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter + +from controllers.console import console_ns DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -19,8 +23,25 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No register_schema_model(namespace, model) +def get_or_create_model(model_name: str, field_def): + existing = console_ns.models.get(model_name) + if existing is None: + existing = console_ns.model(model_name, field_def) + return existing + + +def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None: + """Register multiple StrEnum with a namespace.""" + for model in models: + namespace.schema_model( + model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + ) + + __all__ = [ "DEFAULT_REF_TEMPLATE_SWAGGER_2_0", + "get_or_create_model", + "register_enum_models", "register_schema_model", "register_schema_models", ] diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 9b0d4b1a78..c81709e985 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -22,10 +22,10 @@ api_key_fields = { "created_at": TimestampField, } -api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} - api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) +api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} + api_key_list_model = console_ns.model( "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} ) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index dad184c54b..8c371da596 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,9 +9,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest -from controllers.common.schema import register_schema_models +from controllers.common.helpers import FileInfo +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model +from controllers.console.workspace.models import LoadBalancingPayload from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -22,18 +24,36 @@ from controllers.console.wraps import ( ) from core.file import helpers as file_helpers from core.ops.ops_trace_manager import OpsTraceManager -from core.workflow.enums import NodeType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.enums import NodeType, WorkflowExecutionStatus from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required -from models import App, Workflow +from models import App, DatasetPermissionEnum, Workflow from models.model import IconType from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.entities.knowledge_entities.knowledge_entities import ( + DataSource, + InfoList, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + RerankingModel, + Rule, + Segmentation, + WebsiteInfo, + WeightKeywordSetting, + WeightModel, + WeightVectorSetting, +) from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +register_enum_models(console_ns, IconType) + class AppListQuery(BaseModel): page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") @@ -151,7 +171,7 @@ def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | if icon is None or icon_type is None: return None icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) - if icon_type_value.lower() != IconType.IMAGE.value: + if icon_type_value.lower() != IconType.IMAGE: return None return file_helpers.get_signed_file_url(icon) @@ -391,6 +411,8 @@ class AppExportResponse(ResponseModel): data: str +register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum) + register_schema_models( console_ns, AppListQuery, @@ -414,6 +436,22 @@ register_schema_models( AppDetailWithSite, AppPagination, AppExportResponse, + Segmentation, + PreProcessingRule, + Rule, + WeightVectorSetting, + WeightKeywordSetting, + WeightModel, + RerankingModel, + InfoList, + NotionInfo, + FileInfo, + WebsiteInfo, + NotionPage, + NotionIcon, + RerankingModel, + DataSource, + LoadBalancingPayload, ) diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 22e2aeb720..fdef54ba5a 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class AppImportPayload(BaseModel): mode: str = Field(..., description="Import mode") - yaml_content: str | None = None - yaml_url: str | None = None - name: str | None = None - description: str | None = None - icon_type: str | None = None - icon: str | None = None - icon_background: str | None = None - app_id: str | None = None + yaml_content: str | None = Field(None) + yaml_url: str | None = Field(None) + name: str | None = Field(None) + description: str | None = Field(None) + icon_type: str | None = Field(None) + icon: str | None = Field(None) + icon_background: str | None = Field(None) + app_id: str | None = Field(None) console_ns.schema_model( diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index acaf85a6b1..755463cb70 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.app.workflow_run import workflow_run_node_execution_model from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -35,7 +36,6 @@ from extensions.ext_database import db from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -88,26 +88,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy() workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items") workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy) -# Reuse workflow_run_node_execution_model from workflow_run.py if already registered -# Otherwise register it here -from fields.end_user_fields import simple_end_user_fields - -simple_end_user_model = None -try: - simple_end_user_model = console_ns.models.get("SimpleEndUser") -except AttributeError: - pass -if simple_end_user_model is None: - simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields) - -workflow_run_node_execution_model = None -try: - workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution") -except AttributeError: - pass -if workflow_run_node_execution_model is None: - workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields) - class SyncDraftWorkflowPayload(BaseModel): graph: dict[str, Any] diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 9433b732e4..8236e766ae 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -1,13 +1,14 @@ import logging from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config +from controllers.common.schema import get_or_create_model from extensions.ext_database import db from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields from libs.login import current_user, login_required @@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s logger = logging.getLogger(__name__) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields) + +triggers_list_fields_copy = triggers_list_fields.copy() +triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model)) +triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy) + +webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields) + class Parser(BaseModel): node_id: str @@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(webhook_trigger_fields) + @marshal_with(webhook_trigger_model) def get(self, app_model: App): """Get webhook trigger for a node""" args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -80,7 +89,7 @@ class AppTriggersApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(triggers_list_fields) + @marshal_with(triggers_list_model) def get(self, app_model: App): """Get app triggers list""" assert isinstance(current_user, Account) @@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(trigger_fields) + @marshal_with(trigger_model) def post(self, app_model: App): """Update app trigger (enable/disable)""" args = ParserEnable.model_validate(console_ns.payload) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index d05e726dcb..01e9bf77c0 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -3,13 +3,13 @@ from collections.abc import Generator from typing import Any, cast from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.common.schema import register_schema_model +from controllers.common.schema import get_or_create_model, register_schema_model from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner @@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db -from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields +from fields.data_source_fields import ( + integrate_fields, + integrate_icon_fields, + integrate_list_fields, + integrate_notion_info_list_fields, + integrate_page_fields, + integrate_workspace_fields, +) from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DataSourceOauthBinding, Document @@ -49,6 +56,49 @@ class DataSourceNotionPreviewQuery(BaseModel): register_schema_model(console_ns, NotionEstimatePayload) +integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields) + +integrate_page_fields_copy = integrate_page_fields.copy() +integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True) +integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy) + +integrate_workspace_fields_copy = integrate_workspace_fields.copy() +integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model)) +integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy) + +integrate_fields_copy = integrate_fields.copy() +integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model) +integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy) + +integrate_list_fields_copy = integrate_list_fields.copy() +integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model)) +integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy) + +notion_page_fields = { + "page_name": fields.String, + "page_id": fields.String, + "page_icon": fields.Nested(integrate_icon_model, allow_null=True), + "is_bound": fields.Boolean, + "parent_id": fields.String, + "type": fields.String, +} +notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields) + +notion_workspace_fields = { + "workspace_name": fields.String, + "workspace_id": fields.String, + "workspace_icon": fields.String, + "pages": fields.List(fields.Nested(notion_page_model)), +} +notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields) + +integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy() +integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model)) +integrate_notion_info_list_model = get_or_create_model( + "NotionIntegrateInfoList", integrate_notion_info_list_fields_copy +) + + @console_ns.route( "/data-source/integrates", "/data-source/integrates//", @@ -57,7 +107,7 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_fields) + @marshal_with(integrate_list_model) def get(self): _, current_tenant_id = current_account_with_tenant() @@ -142,7 +192,7 @@ class DataSourceNotionListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_notion_info_list_fields) + @marshal_with(integrate_notion_info_list_model) def get(self): current_user, current_tenant_id = current_account_with_tenant() diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 37c828c3a8..8fbbc51e21 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.apikey import ( api_key_item_model, @@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( + content_fields, dataset_detail_fields, dataset_fields, dataset_query_detail_fields, @@ -41,6 +42,7 @@ from fields.dataset_fields import ( doc_metadata_fields, external_knowledge_info_fields, external_retrieval_model_fields, + file_info_fields, icon_info_fields, keyword_setting_fields, reranking_model_fields, @@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService - -def _get_or_create_model(model_name: str, field_def): - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing - - # Register models for flask_restx to avoid dict type issues in Swagger -dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields) +dataset_base_model = get_or_create_model("DatasetBase", dataset_fields) -tag_model = _get_or_create_model("Tag", tag_fields) +tag_model = get_or_create_model("Tag", tag_fields) -keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) -vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields) +keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) +vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields) weighted_score_fields_copy = weighted_score_fields.copy() weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model) weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model) -weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) +weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) -reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields) +reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields) dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy() dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model) dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True) -dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) +dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) -external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) +external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) -external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) +external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) -doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields) +doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields) -icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields) +icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields) dataset_detail_fields_copy = dataset_detail_fields.copy() dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model) @@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True) dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model)) dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model) -dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy) +dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy) -dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields) +file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields) -app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields) +content_fields_copy = content_fields.copy() +content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True) +content_model = get_or_create_model("DatasetContent", content_fields_copy) + +dataset_query_detail_fields_copy = dataset_query_detail_fields.copy() +dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model) +dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy) + +app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields) related_app_list_copy = related_app_list.copy() related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model)) -related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy) +related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy) def _validate_indexing_technique(value: str | None) -> str | None: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 2599e6293a..57fb9abf29 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -14,7 +14,7 @@ from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -72,34 +72,27 @@ logger = logging.getLogger(__name__) DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 -def _get_or_create_model(model_name: str, field_def): - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing - - # Register models for flask_restx to avoid dict type issues in Swagger -dataset_model = _get_or_create_model("Dataset", dataset_fields) +dataset_model = get_or_create_model("Dataset", dataset_fields) -document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields) +document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) document_fields_copy = document_fields.copy() document_fields_copy["doc_metadata"] = fields.List( fields.Nested(document_metadata_model), attribute="doc_metadata_details" ) -document_model = _get_or_create_model("Document", document_fields_copy) +document_model = get_or_create_model("Document", document_fields_copy) document_with_segments_fields_copy = document_with_segments_fields.copy() document_with_segments_fields_copy["doc_metadata"] = fields.List( fields.Nested(document_metadata_model), attribute="doc_metadata_details" ) -document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) +document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) dataset_and_document_fields_copy = dataset_and_document_fields.copy() dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) class DocumentRetryPayload(BaseModel): @@ -1178,7 +1171,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @marshal_with(document_fields) + @marshal_with(document_model) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 16fecb41c6..08e1ddd3e0 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -90,6 +90,7 @@ register_schema_models( ChildChunkCreatePayload, ChildChunkUpdatePayload, ChildChunkBatchUpdatePayload, + ChildChunkUpdateArgs, ) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 588eb6e1b8..86090bcd10 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService -def _get_or_create_model(model_name: str, field_def): - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing - - def _build_dataset_detail_model(): - keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) - vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields) + keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields) + vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields) weighted_score_fields_copy = weighted_score_fields.copy() weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model) weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model) - weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) + weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy) - reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields) + reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields) dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy() dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model) dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True) - dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) + dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy) - tag_model = _get_or_create_model("Tag", tag_fields) - doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields) - external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) - external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) - icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields) + tag_model = get_or_create_model("Tag", tag_fields) + doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields) + external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields) + external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields) + icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields) dataset_detail_fields_copy = dataset_detail_fields.copy() dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model) @@ -64,7 +57,7 @@ def _build_dataset_detail_model(): dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True) dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model)) dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model) - return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy) + return get_or_create_model("DatasetDetail", dataset_detail_fields_copy) try: diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 8eead1696a..05fc4cd714 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with from pydantic import BaseModel from werkzeug.exceptions import NotFound -from controllers.common.schema import register_schema_model, register_schema_models +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from fields.dataset_fields import dataset_metadata_fields from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, MetadataArgs, + MetadataDetail, MetadataOperationData, ) from services.metadata_service import MetadataService @@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel): name: str -register_schema_models(console_ns, MetadataArgs, MetadataOperationData) -register_schema_model(console_ns, MetadataUpdatePayload) +register_schema_models( + console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail +) @console_ns.route("/datasets//metadata") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 720e2ce365..2911b1cf18 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -2,7 +2,7 @@ import logging from typing import Any, NoReturn from flask import Response, request -from flask_restx import Resource, fields, marshal, marshal_with +from flask_restx import Resource, marshal, marshal_with from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -14,7 +14,9 @@ from controllers.console.app.error import ( ) from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage] - _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage] + workflow_draft_variable_list_model, + workflow_draft_variable_list_without_value_model, + workflow_draft_variable_model, ) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required @@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline -from models.workflow import WorkflowDraftVariable from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService @@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel): register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) -def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]: - return var_list.variables - - -_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = { - "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items), - "total": fields.Raw(), -} - -_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { - "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), -} - - def _api_prerequisite(f): """Common prerequisites for all draft workflow variable APIs. @@ -92,7 +79,7 @@ def _api_prerequisite(f): @console_ns.route("/rag/pipelines//workflows/draft/variables") class RagPipelineVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) + @marshal_with(workflow_draft_variable_list_without_value_model) def get(self, pipeline: Pipeline): """ Get draft workflow @@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None: @console_ns.route("/rag/pipelines//workflows/draft/nodes//variables") class RagPipelineNodeVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @marshal_with(workflow_draft_variable_list_model) def get(self, pipeline: Pipeline, node_id: str): validate_node_id(node_id) with Session(bind=db.engine, expire_on_commit=False) as session: @@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource): _PATCH_VALUE_FIELD = "value" @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @marshal_with(workflow_draft_variable_model) def get(self, pipeline: Pipeline, variable_id: str): draft_var_srv = WorkflowDraftVariableService( session=db.session(), @@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource): return variable @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + @marshal_with(workflow_draft_variable_model) @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) def patch(self, pipeline: Pipeline, variable_id: str): # Request payload for file types: @@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList @console_ns.route("/rag/pipelines//workflows/draft/system-variables") class RagPipelineSystemVariableCollectionApi(Resource): @_api_prerequisite - @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) + @marshal_with(workflow_draft_variable_list_model) def get(self, pipeline: Pipeline): return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index d43ee9a6e0..af142b4646 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,9 +1,9 @@ from flask import request -from flask_restx import Resource, marshal_with # type: ignore +from flask_restx import Resource, fields, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from controllers.common.schema import register_schema_models +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( @@ -12,7 +12,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields +from fields.rag_pipeline_fields import ( + leaked_dependency_fields, + pipeline_import_check_dependencies_fields, + pipeline_import_fields, +) from libs.login import current_account_with_tenant, login_required from models.dataset import Pipeline from services.app_dsl_service import ImportStatus @@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel): register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery) +pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields) + +leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields) +pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy() +pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List( + fields.Nested(leaked_dependency_model) +) +pipeline_import_check_dependencies_model = get_or_create_model( + "RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy +) + + @console_ns.route("/rag/pipelines/imports") class RagPipelineImportApi(Resource): @setup_required @login_required @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_fields) + @marshal_with(pipeline_import_model) @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__]) def post(self): # Check user role first @@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource): @login_required @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_fields) + @marshal_with(pipeline_import_model) def post(self, import_id): current_user, _ = current_account_with_tenant() @@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource): @get_rag_pipeline @account_initialization_required @edit_permission_required - @marshal_with(pipeline_import_check_dependencies_fields) + @marshal_with(pipeline_import_check_dependencies_model) def get(self, pipeline: Pipeline): with Session(db.engine) as session: import_service = RagPipelineDslService(session) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 02efc54eea..d34fd5088d 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -17,6 +17,13 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) +from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow_run import ( + workflow_run_detail_model, + workflow_run_node_execution_list_model, + workflow_run_node_execution_model, + workflow_run_pagination_model, +) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, @@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory -from fields.workflow_fields import workflow_fields, workflow_pagination_fields -from fields.workflow_run_fields import ( - workflow_run_detail_fields, - workflow_run_node_execution_fields, - workflow_run_node_execution_list_fields, - workflow_run_pagination_fields, -) from libs import helper from libs.helper import TimestampField from libs.login import current_account_with_tenant, current_user, login_required @@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource): @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def get(self, pipeline: Pipeline): """ Get draft rag pipeline's workflow @@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource): @edit_permission_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow node @@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def get(self, pipeline: Pipeline): """ Get published pipeline @@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_pagination_fields) + @marshal_with(workflow_pagination_model) def get(self, pipeline: Pipeline): """ Get published workflows @@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource): @account_initialization_required @edit_permission_required @get_rag_pipeline - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def patch(self, pipeline: Pipeline, workflow_id: str): """ Update workflow attributes @@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_pagination_fields) + @marshal_with(workflow_run_pagination_model) def get(self, pipeline: Pipeline): """ Get workflow run list @@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_detail_fields) + @marshal_with(workflow_run_detail_model) def get(self, pipeline: Pipeline, run_id): """ Get workflow run detail @@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_list_fields) + @marshal_with(workflow_run_node_execution_list_model) def get(self, pipeline: Pipeline, run_id: str): """ Get workflow run node execution list @@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def get(self, pipeline: Pipeline, node_id: str): rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource): @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_run_node_execution_fields) + @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline): """ Set datasource variables diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index b77eac605e..aca766567f 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,16 +2,17 @@ import logging from typing import Any from flask import request -from flask_restx import Resource, marshal_with +from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound +from controllers.common.schema import get_or_create_model from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from fields.installed_app_fields import installed_app_list_fields +from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp @@ -35,11 +36,22 @@ class InstalledAppsListQuery(BaseModel): logger = logging.getLogger(__name__) +app_model = get_or_create_model("InstalledAppInfo", app_fields) + +installed_app_fields_copy = installed_app_fields.copy() +installed_app_fields_copy["app"] = fields.Nested(app_model) +installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy) + +installed_app_list_fields_copy = installed_app_list_fields.copy() +installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model)) +installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy) + + @console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required - @marshal_with(installed_app_list_fields) + @marshal_with(installed_app_list_model) def get(self): query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 362513ec1c..c9920c97cf 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from constants.languages import languages +from controllers.common.schema import get_or_create_model from controllers.console import console_ns from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField @@ -19,8 +20,10 @@ app_fields = { "icon_background": fields.String, } +app_model = get_or_create_model("RecommendedAppInfo", app_fields) + recommended_app_fields = { - "app": fields.Nested(app_fields, attribute="app"), + "app": fields.Nested(app_model, attribute="app"), "app_id": fields.String, "description": fields.String(attribute="description"), "copyright": fields.String, @@ -32,11 +35,15 @@ recommended_app_fields = { "can_trial": fields.Boolean, } +recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields) + recommended_app_list_fields = { - "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "recommended_apps": fields.List(fields.Nested(recommended_app_model)), "categories": fields.List(fields.String), } +recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields) + class RecommendedAppsQuery(BaseModel): language: str | None = Field(default=None) @@ -53,7 +60,7 @@ class RecommendedAppListApi(Resource): @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) @login_required @account_initialization_required - @marshal_with(recommended_app_list_fields) + @marshal_with(recommended_app_list_model) def get(self): # language args args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 97d856bebe..1eb0cdb019 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -2,13 +2,14 @@ import logging from typing import Any, cast from flask import request -from flask_restx import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse -from controllers.console import api +from controllers.common.schema import get_or_create_model +from controllers.console import api, console_ns from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -42,9 +43,21 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db -from fields.app_fields import app_detail_fields_with_site +from fields.app_fields import ( + app_detail_fields_with_site, + deleted_tool_fields, + model_config_fields, + site_fields, + tag_fields, +) from fields.dataset_fields import dataset_fields -from fields.workflow_fields import workflow_fields +from fields.member_fields import build_simple_account_model +from fields.workflow_fields import ( + conversation_variable_fields, + pipeline_variable_fields, + workflow_fields, + workflow_partial_fields, +) from libs import helper from libs.helper import uuid_value from libs.login import current_user @@ -74,6 +87,36 @@ from services.recommended_app_service import RecommendedAppService logger = logging.getLogger(__name__) +model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields) +workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields) +deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields) +tag_model = get_or_create_model("TrialTag", tag_fields) +site_model = get_or_create_model("TrialSite", site_fields) + +app_detail_fields_with_site_copy = app_detail_fields_with_site.copy() +app_detail_fields_with_site_copy["model_config"] = fields.Nested( + model_config_model, attribute="app_model_config", allow_null=True +) +app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True) +app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model)) +app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) +app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) +app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) + +simple_account_model = build_simple_account_model(console_ns) +conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) +pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) + +workflow_fields_copy = workflow_fields.copy() +workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account") +workflow_fields_copy["updated_by"] = fields.Nested( + simple_account_model, attribute="updated_by_account", allow_null=True +) +workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model)) +workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model)) +workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) + + class TrialAppWorkflowRunApi(TrialAppResource): def post(self, trial_app): """ @@ -437,7 +480,7 @@ class TrialAppParameterApi(Resource): class AppApi(Resource): @trial_feature_enable @get_app_model_with_trial - @marshal_with(app_detail_fields_with_site) + @marshal_with(app_detail_with_site_model) def get(self, app_model): """Get app detail""" @@ -450,7 +493,7 @@ class AppApi(Resource): class AppWorkflowApi(Resource): @trial_feature_enable @get_app_model_with_trial - @marshal_with(workflow_fields) + @marshal_with(workflow_model) def get(self, app_model): """Get workflow detail""" if not app_model.workflow_id: diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 527aabbc3d..38c66525b3 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -171,6 +171,19 @@ reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +integrate_fields = { + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, +} + +integrate_model = console_ns.model("AccountIntegrate", integrate_fields) +integrate_list_model = console_ns.model( + "AccountIntegrateList", + {"data": fields.List(fields.Nested(integrate_model))}, +) + @console_ns.route("/account/init") class AccountInitApi(Resource): @@ -336,21 +349,10 @@ class AccountPasswordApi(Resource): @console_ns.route("/account/integrates") class AccountIntegrateApi(Resource): - integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, - } - - integrate_list_fields = { - "data": fields.List(fields.Nested(integrate_fields)), - } - @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_fields) + @marshal_with(integrate_list_model) def get(self): account, _ = current_account_with_tenant() diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 01cca2a8a0..271cdce3c3 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,11 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, marshal_with +from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field import services from configs import dify_config +from controllers.common.schema import get_or_create_model, register_enum_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -24,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_list_fields +from fields.member_fields import account_with_role_fields, account_with_role_list_fields from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -67,6 +68,13 @@ reg(MemberRoleUpdatePayload) reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) +register_enum_models(console_ns, TenantAccountRole) + +account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) + +account_with_role_list_fields_copy = account_with_role_list_fields.copy() +account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) +account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) @console_ns.route("/workspaces/current/members") @@ -76,7 +84,7 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_fields) + @marshal_with(account_with_role_list_model) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: @@ -227,7 +235,7 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_fields) + @marshal_with(account_with_role_list_model) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 2def57ed7b..583e3e3057 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -5,6 +5,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType @@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel): model_type: ModelType -class ParserPostDefault(BaseModel): - class Inner(BaseModel): - model_type: ModelType - model: str | None = None - provider: str | None = None +class Inner(BaseModel): + model_type: ModelType + model: str | None = None + provider: str | None = None + +class ParserPostDefault(BaseModel): model_settings: list[Inner] @@ -105,19 +107,21 @@ class ParserParameter(BaseModel): model: str -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models( + console_ns, + ParserGetDefault, + ParserPostDefault, + ParserDeleteModels, + ParserPostModels, + ParserGetCredentials, + ParserCreateCredential, + ParserUpdateCredential, + ParserDeleteCredential, + ParserParameter, + Inner, +) - -reg(ParserGetDefault) -reg(ParserPostDefault) -reg(ParserDeleteModels) -reg(ParserPostModels) -reg(ParserGetCredentials) -reg(ParserCreateCredential) -reg(ParserUpdateCredential) -reg(ParserDeleteCredential) -reg(ParserParameter) +register_enum_models(console_ns, ModelType) @console_ns.route("/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ea74fc0337..d1485bc1c0 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required @@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -@console_ns.route("/workspaces/current/plugin/debugging-key") -class PluginDebuggingKeyApi(Resource): - @setup_required - @login_required - @account_initialization_required - @plugin_permission_required(debug_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() - - try: - return { - "key": PluginService.get_debugging_key(tenant_id), - "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, - "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, - } - except PluginDaemonClientSideError as e: - raise ValueError(e) - class ParserList(BaseModel): page: int = Field(default=1, ge=1, description="Page number") page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)") -reg(ParserList) - - -@console_ns.route("/workspaces/current/plugin/list") -class PluginListApi(Resource): - @console_ns.expect(console_ns.models[ParserList.__name__]) - @setup_required - @login_required - @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore - try: - plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) - except PluginDaemonClientSideError as e: - raise ValueError(e) - - return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) - - class ParserLatest(BaseModel): plugin_ids: list[str] @@ -180,23 +136,73 @@ class ParserReadme(BaseModel): language: str = Field(default="en-US") -reg(ParserLatest) -reg(ParserIcon) -reg(ParserAsset) -reg(ParserGithubUpload) -reg(ParserPluginIdentifiers) -reg(ParserGithubInstall) -reg(ParserPluginIdentifierQuery) -reg(ParserTasks) -reg(ParserMarketplaceUpgrade) -reg(ParserGithubUpgrade) -reg(ParserUninstall) -reg(ParserPermissionChange) -reg(ParserDynamicOptions) -reg(ParserDynamicOptionsWithCredentials) -reg(ParserPreferencesChange) -reg(ParserExcludePlugin) -reg(ParserReadme) +register_schema_models( + console_ns, + ParserList, + PluginAutoUpgradeSettingsPayload, + PluginPermissionSettingsPayload, + ParserLatest, + ParserIcon, + ParserAsset, + ParserGithubUpload, + ParserPluginIdentifiers, + ParserGithubInstall, + ParserPluginIdentifierQuery, + ParserTasks, + ParserMarketplaceUpgrade, + ParserGithubUpgrade, + ParserUninstall, + ParserPermissionChange, + ParserDynamicOptions, + ParserDynamicOptionsWithCredentials, + ParserPreferencesChange, + ParserExcludePlugin, + ParserReadme, +) + +register_enum_models( + console_ns, + TenantPluginPermission.DebugPermission, + TenantPluginAutoUpgradeStrategy.UpgradeMode, + TenantPluginAutoUpgradeStrategy.StrategySetting, + TenantPluginPermission.InstallPermission, +) + + +@console_ns.route("/workspaces/current/plugin/debugging-key") +class PluginDebuggingKeyApi(Resource): + @setup_required + @login_required + @account_initialization_required + @plugin_permission_required(debug_required=True) + def get(self): + _, tenant_id = current_account_with_tenant() + + try: + return { + "key": PluginService.get_debugging_key(tenant_id), + "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST, + "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, + } + except PluginDaemonClientSideError as e: + raise ValueError(e) + + +@console_ns.route("/workspaces/current/plugin/list") +class PluginListApi(Resource): + @console_ns.expect(console_ns.models[ParserList.__name__]) + @setup_required + @login_required + @account_initialization_required + def get(self): + _, tenant_id = current_account_with_tenant() + args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore + try: + plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) + except PluginDaemonClientSideError as e: + raise ValueError(e) + + return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) @console_ns.route("/workspaces/current/plugin/list/latest-versions") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 2af269fd91..28864a140a 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound import services @@ -26,6 +26,14 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.tag_service import TagService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +service_api_ns.schema_model( + DatasetPermissionEnum.__name__, + TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + class DatasetCreatePayload(BaseModel): name: str = Field(..., min_length=1, max_length=40) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 1260645624..c85c1cf81e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -16,6 +16,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_enum_models, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( @@ -29,12 +30,20 @@ from controllers.service_api.wraps import ( cloud_edition_billing_resource_check, ) from core.errors.error import ProviderTokenNotInitError +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DatasetService, DocumentService -from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel +from services.entities.knowledge_entities.knowledge_entities import ( + KnowledgeConfig, + PreProcessingRule, + ProcessRule, + RetrievalModel, + Rule, + Segmentation, +) from services.file_service import FileService @@ -76,8 +85,19 @@ class DocumentListQuery(BaseModel): status: str | None = Field(default=None, description="Document status filter") -for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]: - service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore +register_enum_models(service_api_ns, RetrievalMethod) + +register_schema_models( + service_api_ns, + ProcessRule, + RetrievalModel, + DocumentTextCreatePayload, + DocumentTextUpdate, + DocumentListQuery, + Rule, + PreProcessingRule, + Segmentation, +) @service_api_ns.route( diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index b242fd2c3e..95679e6fcb 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -60,6 +60,7 @@ register_schema_models( service_api_ns, SegmentCreatePayload, SegmentListQuery, + SegmentUpdateArgs, SegmentUpdatePayload, ChildChunkCreatePayload, ChildChunkListQuery, diff --git a/api/libs/login.py b/api/libs/login.py index 4b8ee2d1f8..73caa492fe 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from collections.abc import Callable from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS @@ -9,7 +11,9 @@ from werkzeug.local import LocalProxy from configs import dify_config from libs.token import check_csrf_token from models import Account -from models.model import EndUser + +if TYPE_CHECKING: + from models.model import EndUser def current_account_with_tenant(): diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index edcb2a7870..0f42c99246 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -428,10 +428,10 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]: + if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: icon_type = icon_type_value else: - icon_type = IconType.EMOJI.value + icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: