fix: doc not gen bug (#31547)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-01-27 20:19:39 +09:00 committed by GitHub
parent e482588ef8
commit 8ec4233611
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 473 additions and 265 deletions

View File

@ -1,7 +1,11 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -19,8 +23,25 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
register_schema_model(namespace, model)
def get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
"""Register multiple StrEnum with a namespace."""
for model in models:
namespace.schema_model(
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"get_or_create_model",
"register_enum_models",
"register_schema_model",
"register_schema_models",
]

View File

@ -22,10 +22,10 @@ api_key_fields = {
"created_at": TimestampField,
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)

View File

@ -9,9 +9,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -22,18 +24,36 @@ from controllers.console.wraps import (
)
from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App, Workflow
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
WeightVectorSetting,
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@ -151,7 +171,7 @@ def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str |
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE.value:
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
@ -391,6 +411,8 @@ class AppExportResponse(ResponseModel):
data: str
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
register_schema_models(
console_ns,
AppListQuery,
@ -414,6 +436,22 @@ register_schema_models(
AppDetailWithSite,
AppPagination,
AppExportResponse,
Segmentation,
PreProcessingRule,
Rule,
WeightVectorSetting,
WeightKeywordSetting,
WeightModel,
RerankingModel,
InfoList,
NotionInfo,
FileInfo,
WebsiteInfo,
NotionPage,
NotionIcon,
RerankingModel,
DataSource,
LoadBalancingPayload,
)

View File

@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
yaml_content: str | None = Field(None)
yaml_url: str | None = Field(None)
name: str | None = Field(None)
description: str | None = Field(None)
icon_type: str | None = Field(None)
icon: str | None = Field(None)
icon_background: str | None = Field(None)
app_id: str | None = Field(None)
console_ns.schema_model(

View File

@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -35,7 +36,6 @@ from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
@ -88,26 +88,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy()
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
# Otherwise register it here
from fields.end_user_fields import simple_end_user_fields
simple_end_user_model = None
try:
simple_end_user_model = console_ns.models.get("SimpleEndUser")
except AttributeError:
pass
if simple_end_user_model is None:
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
workflow_run_node_execution_model = None
try:
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
except AttributeError:
pass
if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]

View File

@ -1,13 +1,14 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import get_or_create_model
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
node_id: str
@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
@marshal_with(webhook_trigger_model)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
@marshal_with(triggers_list_model)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
@marshal_with(trigger_model)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)

View File

@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model
from controllers.common.schema import get_or_create_model, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from fields.data_source_fields import (
integrate_fields,
integrate_icon_fields,
integrate_list_fields,
integrate_notion_info_list_fields,
integrate_page_fields,
integrate_workspace_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
@ -49,6 +56,49 @@ class DataSourceNotionPreviewQuery(BaseModel):
register_schema_model(console_ns, NotionEstimatePayload)
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
integrate_page_fields_copy = integrate_page_fields.copy()
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
integrate_fields_copy = integrate_fields.copy()
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
integrate_list_fields_copy = integrate_list_fields.copy()
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
notion_page_fields = {
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
notion_workspace_fields = {
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(notion_page_model)),
}
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
integrate_notion_info_list_model = get_or_create_model(
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
)
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
@ -57,7 +107,7 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
@marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@ -142,7 +192,7 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
@marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()

View File

@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
@ -41,6 +42,7 @@ from fields.dataset_fields import (
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = _get_or_create_model("Tag", tag_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:

View File

@ -14,7 +14,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -72,34 +72,27 @@ logger = logging.getLogger(__name__)
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = _get_or_create_model("Dataset", dataset_fields)
dataset_model = get_or_create_model("Dataset", dataset_fields)
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = _get_or_create_model("Document", document_fields_copy)
document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
@ -1178,7 +1171,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_fields)
@marshal_with(document_model)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator

View File

@ -90,6 +90,7 @@ register_schema_models(
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
def _build_dataset_detail_model():
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
tag_model = _get_or_create_model("Tag", tag_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
tag_model = get_or_create_model("Tag", tag_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -64,7 +57,7 @@ def _build_dataset_detail_model():
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:

View File

@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from flask_restx import Resource, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -14,7 +14,9 @@ from controllers.console.app.error import (
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
@ -92,7 +79,7 @@ def _api_prerequisite(f):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)

View File

@ -1,9 +1,9 @@
from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -12,7 +12,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from fields.rag_pipeline_fields import (
leaked_dependency_fields,
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
fields.Nested(leaked_dependency_model)
)
pipeline_import_check_dependencies_model = get_or_create_model(
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields)
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)

View File

@ -17,6 +17,13 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import (
workflow_run_detail_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource):
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get published pipeline
@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_pagination_fields)
@marshal_with(workflow_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get published workflows
@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def patch(self, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_pagination_fields)
@marshal_with(workflow_run_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_detail_fields)
@marshal_with(workflow_run_detail_model)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
@marshal_with(workflow_run_node_execution_list_model)
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def get(self, pipeline: Pipeline, node_id: str):
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline):
"""
Set datasource variables

View File

@ -2,16 +2,17 @@ import logging
from typing import Any
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
@ -35,11 +36,22 @@ class InstalledAppsListQuery(BaseModel):
logger = logging.getLogger(__name__)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(installed_app_list_fields)
@marshal_with(installed_app_list_model)
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()

View File

@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
@ -19,8 +20,10 @@ app_fields = {
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_fields, attribute="app"),
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
@ -32,11 +35,15 @@ recommended_app_fields = {
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
@ -53,7 +60,7 @@ class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

View File

@ -2,13 +2,14 @@ import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.common.schema import get_or_create_model
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@ -42,9 +43,21 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.app_fields import (
app_detail_fields_with_site,
deleted_tool_fields,
model_config_fields,
site_fields,
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from fields.member_fields import build_simple_account_model
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
workflow_fields,
workflow_partial_fields,
)
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
@ -74,6 +87,36 @@ from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
tag_model = get_or_create_model("TrialTag", tag_fields)
site_model = get_or_create_model("TrialSite", site_fields)
app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
app_detail_fields_with_site_copy["model_config"] = fields.Nested(
model_config_model, attribute="app_model_config", allow_null=True
)
app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = build_simple_account_model(console_ns)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
workflow_fields_copy = workflow_fields.copy()
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
workflow_fields_copy["updated_by"] = fields.Nested(
simple_account_model, attribute="updated_by_account", allow_null=True
)
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
@ -437,7 +480,7 @@ class TrialAppParameterApi(Resource):
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
@ -450,7 +493,7 @@ class AppApi(Resource):
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:

View File

@ -171,6 +171,19 @@ reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
)
@console_ns.route("/account/init")
class AccountInitApi(Resource):
@ -336,21 +349,10 @@ class AccountPasswordApi(Resource):
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
integrate_list_fields = {
"data": fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
@marshal_with(integrate_list_model)
def get(self):
account, _ = current_account_with_tenant()

View File

@ -1,11 +1,12 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@ -24,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@ -67,6 +68,13 @@ reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole)
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
@console_ns.route("/workspaces/current/members")
@ -76,7 +84,7 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
@marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
@ -227,7 +235,7 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
@marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:

View File

@ -5,6 +5,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel):
model_type: ModelType
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
model: str | None = None
provider: str | None = None
class Inner(BaseModel):
model_type: ModelType
model: str | None = None
provider: str | None = None
class ParserPostDefault(BaseModel):
model_settings: list[Inner]
@ -105,19 +107,21 @@ class ParserParameter(BaseModel):
model: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
register_schema_models(
console_ns,
ParserGetDefault,
ParserPostDefault,
ParserDeleteModels,
ParserPostModels,
ParserGetCredentials,
ParserCreateCredential,
ParserUpdateCredential,
ParserDeleteCredential,
ParserParameter,
Inner,
)
reg(ParserGetDefault)
reg(ParserPostDefault)
reg(ParserDeleteModels)
reg(ParserPostModels)
reg(ParserGetCredentials)
reg(ParserCreateCredential)
reg(ParserUpdateCredential)
reg(ParserDeleteCredential)
reg(ParserParameter)
register_enum_models(console_ns, ModelType)
@console_ns.route("/workspaces/current/default-model")

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class ParserList(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
class ParserLatest(BaseModel):
plugin_ids: list[str]
@ -180,23 +136,73 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
register_schema_models(
console_ns,
ParserList,
PluginAutoUpgradeSettingsPayload,
PluginPermissionSettingsPayload,
ParserLatest,
ParserIcon,
ParserAsset,
ParserGithubUpload,
ParserPluginIdentifiers,
ParserGithubInstall,
ParserPluginIdentifierQuery,
ParserTasks,
ParserMarketplaceUpgrade,
ParserGithubUpgrade,
ParserUninstall,
ParserPermissionChange,
ParserDynamicOptions,
ParserDynamicOptionsWithCredentials,
ParserPreferencesChange,
ParserExcludePlugin,
ParserReadme,
)
register_enum_models(
console_ns,
TenantPluginPermission.DebugPermission,
TenantPluginAutoUpgradeStrategy.UpgradeMode,
TenantPluginAutoUpgradeStrategy.StrategySetting,
TenantPluginPermission.InstallPermission,
)
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -26,6 +26,14 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
service_api_ns.schema_model(
DatasetPermissionEnum.__name__,
TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)

View File

@ -16,6 +16,7 @@ from controllers.common.errors import (
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@ -29,12 +30,20 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
PreProcessingRule,
ProcessRule,
RetrievalModel,
Rule,
Segmentation,
)
from services.file_service import FileService
@ -76,8 +85,19 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(
service_api_ns,
ProcessRule,
RetrievalModel,
DocumentTextCreatePayload,
DocumentTextUpdate,
DocumentListQuery,
Rule,
PreProcessingRule,
Segmentation,
)
@service_api_ns.route(

View File

@ -60,6 +60,7 @@ register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdateArgs,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,

View File

@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import TYPE_CHECKING, Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@ -9,7 +11,9 @@ from werkzeug.local import LocalProxy
from configs import dify_config
from libs.token import check_csrf_token
from models import Account
from models.model import EndUser
if TYPE_CHECKING:
from models.model import EndUser
def current_account_with_tenant():

View File

@ -428,10 +428,10 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
if icon_type_value in [IconType.EMOJI.value, IconType.IMAGE.value, IconType.LINK.value]:
if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]:
icon_type = icon_type_value
else:
icon_type = IconType.EMOJI.value
icon_type = IconType.EMOJI
icon = icon or str(app_data.get("icon", ""))
if app: