diff --git a/api/AGENTS.md b/api/AGENTS.md index 8e5d9f600d..eb4404509d 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -193,6 +193,10 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. - Document non-obvious behaviour with concise docstrings and comments. +- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`. + In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response + DTOs with `register_response_schema_models(...)`, serialize with `ResponseModel.model_validate(...).model_dump(...)`, + and avoid adding new legacy `ns.model(...)`, `@marshal_with(...)`, or GET `@ns.expect(...)` patterns. ### Miscellaneous diff --git a/api/controllers/API_SCHEMA_GUIDE.md b/api/controllers/API_SCHEMA_GUIDE.md new file mode 100644 index 0000000000..5b1b055b09 --- /dev/null +++ b/api/controllers/API_SCHEMA_GUIDE.md @@ -0,0 +1,193 @@ +# API Schema Guide + +This guide describes the expected Flask-RESTX + Pydantic pattern for controller request payloads, query +parameters, response schemas, and Swagger documentation. + +## Principles + +- Use Pydantic `BaseModel` for request bodies and query parameters. +- Use `fields.base.ResponseModel` for response DTOs. +- Keep runtime validation and Swagger documentation wired to the same Pydantic model. +- Prefer explicit validation and serialization in controller methods over Flask-RESTX marshalling. +- Do not add new Flask-RESTX `fields.*` dictionaries, `Namespace.model(...)` exports, or `@marshal_with(...)` for migrated or new endpoints. +- Do not use `@ns.expect(...)` for GET query parameters. Flask-RESTX documents that as a request body. + +## Naming + +- Request body models: use a `Payload` suffix. + - Example: `WorkflowRunPayload`, `DatasourceVariablesPayload`. +- Query parameter models: use a `Query` suffix. + - Example: `WorkflowRunListQuery`, `MessageListQuery`. +- Response models: use a `Response` suffix and inherit from `ResponseModel`. + - Example: `WorkflowRunDetailResponse`, `WorkflowRunNodeExecutionListResponse`. +- Use `ListResponse` or `PaginationResponse` for wrapper responses. + - Example: `WorkflowRunNodeExecutionListResponse`, `WorkflowRunPaginationResponse`. +- Keep these models near the controller when they are endpoint-specific. Move them to `fields/*_fields.py` only when shared by multiple controllers. + +## Registering Models For Swagger + +Use helpers from `controllers.common.schema`. + +```python +from controllers.common.schema import ( + query_params_from_model, + register_response_schema_models, + register_schema_models, +) +``` + +Register request payload and query models with `register_schema_models(...)`: + +```python +register_schema_models( + console_ns, + WorkflowRunPayload, + WorkflowRunListQuery, +) +``` + +Register response models with `register_response_schema_models(...)`: + +```python +register_response_schema_models( + console_ns, + WorkflowRunDetailResponse, + WorkflowRunPaginationResponse, +) +``` + +Response models are registered in Pydantic serialization mode. This matters when a response model uses +`validation_alias` to read internal object attributes but emits public API field names. For example, a response model +can validate from `inputs_dict` while documenting and serializing `inputs`. + +## Request Bodies + +For non-GET request bodies: + +1. Define a Pydantic `Payload` model. +2. Register it with `register_schema_models(...)`. +3. Use `@ns.expect(ns.models[Payload.__name__])` for Swagger documentation. +4. Validate from `ns.payload or {}` inside the controller. + +```python +class DraftWorkflowNodeRunPayload(BaseModel): + inputs: dict[str, Any] + query: str = "" + + +register_schema_models(console_ns, DraftWorkflowNodeRunPayload) + + +@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) +def post(self, app_model: App, node_id: str): + payload = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {}) + result = service.run(..., inputs=payload.inputs, query=payload.query) + return WorkflowRunNodeExecutionResponse.model_validate(result, from_attributes=True).model_dump(mode="json") +``` + +## Query Parameters + +For GET query parameters: + +1. Define a Pydantic `Query` model. +2. Register it with `register_schema_models(...)` if it is referenced elsewhere in docs, or only use + `query_params_from_model(...)` if a body schema is not needed. +3. Use `@ns.doc(params=query_params_from_model(QueryModel))`. +4. Validate from `request.args.to_dict(flat=True)` or an explicit dict when type coercion is needed. + +```python +class WorkflowRunListQuery(BaseModel): + last_id: str | None = Field(default=None, description="Last run ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)") + + +@console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) +def get(self, app_model: App): + query = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) + result = service.list(..., limit=query.limit, last_id=query.last_id) + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") +``` + +Do not do this for GET query parameters: + +```python +@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) +def get(...): + ... +``` + +That documents a GET request body and is not the expected contract. + +## Responses + +Response models should inherit from `ResponseModel`: + +```python +class WorkflowRunNodeExecutionResponse(ResponseModel): + id: str + inputs: Any = Field(default=None, validation_alias="inputs_dict") + process_data: Any = Field(default=None, validation_alias="process_data_dict") + outputs: Any = Field(default=None, validation_alias="outputs_dict") +``` + +Document response models with `@ns.response(...)`: + +```python +@console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], +) +def post(...): + ... +``` + +Serialize explicitly: + +```python +return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, + from_attributes=True, +).model_dump(mode="json") +``` + +If the service can return `None`, translate that into the expected HTTP error before validation: + +```python +workflow_run = service.get_workflow_run(...) +if workflow_run is None: + raise NotFound("Workflow run not found") + +return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") +``` + +## Legacy Flask-RESTX Patterns + +Avoid adding these patterns to new or migrated endpoints: + +- `ns.model(...)` for new request/response DTOs. +- Module-level exported RESTX model objects such as `workflow_run_detail_model`. +- `fields.Nested({...})` with raw inline dict field maps. +- `@marshal_with(...)` for response serialization. +- `@ns.expect(...)` for GET query params. + +Existing legacy field dictionaries may remain where an endpoint has not yet been migrated. Keep that compatibility local +to the legacy area and avoid importing RESTX model objects from controllers. + +## Verifying Swagger + +For schema and documentation changes, run focused tests and generate Swagger JSON: + +```bash +uv run --project . pytest tests/unit_tests/controllers/common/test_schema.py +uv run --project . pytest tests/unit_tests/commands/test_generate_swagger_specs.py tests/unit_tests/controllers/test_swagger.py +uv run --project . dev/generate_swagger_specs.py --output-dir /tmp/dify-openapi-check +``` + +Inspect affected endpoints with `jq`. Check that: + +- GET parameters are `in: query`. +- Request bodies appear only where the endpoint has a body. +- Responses reference the expected `*Response` schema. +- Response schemas use public serialized names, not internal validation aliases like `inputs_dict`. + diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py index 57070f1c80..58140f3ac8 100644 --- a/api/controllers/common/schema.py +++ b/api/controllers/common/schema.py @@ -8,7 +8,7 @@ These helpers keep that translation centralized so models registered through from collections.abc import Mapping from enum import StrEnum -from typing import Any, NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict from flask_restx import Namespace from pydantic import BaseModel, TypeAdapter @@ -54,16 +54,23 @@ def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None _register_json_schema(namespace, nested_name, nested_schema) -def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: - """Register a BaseModel and its nested schema definitions for Swagger documentation.""" +JsonSchemaMode = Literal["validation", "serialization"] + +def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode: JsonSchemaMode) -> None: _register_json_schema( namespace, model.__name__, - model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), + model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0, mode=mode), ) +def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: + """Register a BaseModel and its nested schema definitions for Swagger documentation.""" + + _register_schema_model(namespace, model, mode="validation") + + def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None: """Register multiple BaseModels with a namespace.""" @@ -71,6 +78,19 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No register_schema_model(namespace, model) +def register_response_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: + """Register a BaseModel using its serialized response shape.""" + + _register_schema_model(namespace, model, mode="serialization") + + +def register_response_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None: + """Register multiple response BaseModels using their serialized response shape.""" + + for model in models: + register_response_schema_model(namespace, model) + + def get_or_create_model(model_name: str, field_def): # Import lazily to avoid circular imports between console controllers and schema helpers. from controllers.console import console_ns @@ -190,6 +210,8 @@ __all__ = [ "get_or_create_model", "query_params_from_model", "register_enum_models", + "register_response_schema_model", + "register_response_schema_models", "register_schema_model", "register_schema_models", ] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a32c3420bb..ae2b1007dd 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,6 +3,7 @@ import io from collections.abc import Callable from functools import wraps from typing import cast +from uuid import UUID from flask import request from flask_restx import Resource @@ -21,8 +22,6 @@ from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp from services.billing_service import BillingService, LangContentDict -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InsertExploreAppPayload(BaseModel): app_id: str = Field(...) @@ -59,15 +58,7 @@ class InsertExploreBannerPayload(BaseModel): model_config = {"populate_by_name": True} -console_ns.schema_model( - InsertExploreAppPayload.__name__, - InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - InsertExploreBannerPayload.__name__, - InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload) def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: @@ -191,7 +182,7 @@ class InsertExploreAppApi(Resource): @console_ns.response(204, "App removed successfully") @only_edition_cloud @admin_required - def delete(self, app_id): + def delete(self, app_id: UUID): with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) @@ -404,11 +395,11 @@ class BatchAddNotificationAccountsApi(Resource): raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") try: - content = file.read().decode("utf-8") + content = file.stream.read().decode("utf-8") except UnicodeDecodeError: try: - file.seek(0) - content = file.read().decode("gbk") + file.stream.seek(0) + content = file.stream.read().decode("gbk") except UnicodeDecodeError: raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index ed66da1be5..ad21671176 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -34,7 +34,7 @@ class AdvancedPromptTemplateList(Resource): @login_required @account_initialization_required def get(self): - args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) prompt_args: AdvancedPromptTemplateArgs = { "app_mode": args.app_mode, "model_mode": args.model_mode, diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index cfdb9cf417..c05600ced5 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,6 +2,7 @@ from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -10,8 +11,6 @@ from libs.login import login_required from models.model import AppMode from services.agent_service import AgentService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AgentLogQuery(BaseModel): message_id: str = Field(..., description="Message UUID") @@ -23,9 +22,7 @@ class AgentLogQuery(BaseModel): return uuid_value(value) -console_ns.schema_model( - AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, AgentLogQuery) @console_ns.route("/apps//agent/logs") @@ -44,6 +41,6 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 528785931e..cfeaec4af9 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,4 +1,5 @@ from typing import Any, Literal +from uuid import UUID from flask import abort, make_response, request from flask_restx import Resource @@ -33,8 +34,6 @@ from services.annotation_service import ( UpsertAnnotationArgs, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AnnotationReplyPayload(BaseModel): score_threshold: float = Field(..., description="Score threshold for annotation matching") @@ -87,17 +86,6 @@ class AnnotationFilePayload(BaseModel): return uuid_value(value) -def reg(model: type[BaseModel]) -> None: - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AnnotationReplyPayload) -reg(AnnotationSettingUpdatePayload) -reg(AnnotationListQuery) -reg(CreateAnnotationPayload) -reg(UpdateAnnotationPayload) -reg(AnnotationReplyStatusQuery) -reg(AnnotationFilePayload) register_schema_models( console_ns, Annotation, @@ -105,6 +93,13 @@ register_schema_models( AnnotationExportList, AnnotationHitHistory, AnnotationHitHistoryList, + AnnotationReplyPayload, + AnnotationSettingUpdatePayload, + AnnotationListQuery, + CreateAnnotationPayload, + UpdateAnnotationPayload, + AnnotationReplyStatusQuery, + AnnotationFilePayload, ) @@ -121,8 +116,7 @@ class AnnotationReplyActionApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, action: Literal["enable", "disable"]): - app_id = str(app_id) + def post(self, app_id: UUID, action: Literal["enable", "disable"]): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": @@ -131,9 +125,9 @@ class AnnotationReplyActionApi(Resource): "embedding_provider_name": args.embedding_provider_name, "embedding_model_name": args.embedding_model_name, } - result = AppAnnotationService.enable_app_annotation(enable_args, app_id) + result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id)) case "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + result = AppAnnotationService.disable_app_annotation(str(app_id)) return result, 200 @@ -148,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) + def get(self, app_id: UUID): + result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id)) return result, 200 @@ -166,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id, annotation_setting_id): - app_id = str(app_id) + def post(self, app_id: UUID, annotation_setting_id): annotation_setting_id = str(annotation_setting_id) args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) + result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args) return result, 200 @@ -189,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id, action): + def get(self, app_id: UUID, job_id, action): job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -217,14 +209,13 @@ class AnnotationApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + def get(self, app_id: UUID): + args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit keyword = args.keyword - app_id = str(app_id) - annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response = AnnotationList( data=annotation_models, @@ -246,8 +237,7 @@ class AnnotationApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id): - app_id = str(app_id) + def post(self, app_id: UUID): args = CreateAnnotationPayload.model_validate(console_ns.payload) upsert_args: UpsertAnnotationArgs = {} if args.answer is not None: @@ -258,15 +248,14 @@ class AnnotationApi(Resource): upsert_args["message_id"] = args.message_id if args.question is not None: upsert_args["question"] = args.question - annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id): - app_id = str(app_id) + def delete(self, app_id: UUID): # Use request.args.getlist to get annotation_ids array directly annotation_ids = request.args.getlist("annotation_id") @@ -280,11 +269,11 @@ class AnnotationApi(Resource): "message": "annotation_ids are required if the parameter is provided.", }, 400 - result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids) + result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids) return result, 204 # If no annotation_ids are provided, handle clearing all annotations else: - AppAnnotationService.clear_all_annotations(app_id) + AppAnnotationService.clear_all_annotations(str(app_id)) return {"result": "success"}, 204 @@ -303,9 +292,8 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) + def get(self, app_id: UUID): + annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id)) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") @@ -331,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) + def post(self, app_id: UUID, annotation_id: UUID): args = UpdateAnnotationPayload.model_validate(console_ns.payload) update_args: UpdateAnnotationArgs = {} if args.answer is not None: update_args["answer"] = args.answer if args.question is not None: update_args["question"] = args.question - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) + annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) - AppAnnotationService.delete_app_annotation(app_id, annotation_id) + def delete(self, app_id: UUID, annotation_id: UUID): + AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) return {"result": "success"}, 204 @@ -371,11 +355,9 @@ class AnnotationBatchImportApi(Resource): @annotation_import_rate_limit @annotation_import_concurrency_limit @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): from configs import dify_config - app_id = str(app_id) - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -391,9 +373,9 @@ class AnnotationBatchImportApi(Resource): raise ValueError("Invalid file type. Only CSV files are allowed") # Check file size before processing - file.seek(0, 2) # Seek to end of file - file_size = file.tell() - file.seek(0) # Reset to beginning + file.stream.seek(0, 2) # Seek to end of file + file_size = file.stream.tell() + file.stream.seek(0) # Reset to beginning max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > max_size_bytes: @@ -406,7 +388,7 @@ class AnnotationBatchImportApi(Resource): if file_size == 0: raise ValueError("The uploaded file is empty") - return AppAnnotationService.batch_import_app_annotations(app_id, file) + return AppAnnotationService.batch_import_app_annotations(str(app_id), file) @console_ns.route("/apps//annotations/batch-import-status/") @@ -421,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id): - job_id = str(job_id) + def get(self, app_id: UUID, job_id: UUID): indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" cache_result = redis_client.get(indexing_cache_key) if cache_result is None: @@ -456,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id, annotation_id): + def get(self, app_id: UUID, annotation_id: UUID): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - app_id = str(app_id) - annotation_id = str(annotation_id) annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( - app_id, annotation_id, page, limit + str(app_id), str(annotation_id), page, limit ) history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( annotation_hit_history_list, from_attributes=True diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 58ed6efc14..a8ab5bec48 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,6 +3,7 @@ import re import uuid from datetime import datetime from typing import Any, Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -701,7 +702,7 @@ class AppExportApi(Resource): @edit_permission_required def get(self, app_model): """Export app""" - args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) payload = AppExportResponse( data=AppDslService.export_dsl( @@ -840,10 +841,10 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + def get(self, app_id: UUID): """Get app trace""" with session_factory.create_session() as session: - app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session) + app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session) return app_trace_config @@ -857,12 +858,12 @@ class AppTraceApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): # add app trace args = AppTracePayload.model_validate(console_ns.payload) OpsTraceManager.update_app_tracing_config( - app_id=app_id, + app_id=str(app_id), enabled=args.enabled, tracing_provider=args.tracing_provider, ) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 91fbe4a85a..5b673f3394 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -173,7 +173,7 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fe274e4c9a..6a20296cff 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -37,7 +38,6 @@ from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class BaseMessagePayload(BaseModel): @@ -65,13 +65,7 @@ class ChatMessagePayload(BaseMessagePayload): return uuid_value(value) -console_ns.schema_model( - CompletionMessagePayload.__name__, - CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) # define completion message api for user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b2b1049f0c..c7347933cb 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -39,8 +39,6 @@ from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class BaseConversationQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword") @@ -70,15 +68,6 @@ class ChatConversationQuery(BaseConversationQuery): ) -console_ns.schema_model( - CompletionConversationQuery.__name__, - CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatConversationQuery.__name__, - ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - register_schema_models( console_ns, CompletionConversationQuery, @@ -89,6 +78,8 @@ register_schema_models( ConversationWithSummaryPaginationResponse, ConversationDetailResponse, ResultResponse, + CompletionConversationQuery, + ChatConversationQuery, ) @@ -107,7 +98,7 @@ class CompletionConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) query = sa.select(Conversation).where( Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) @@ -221,7 +212,7 @@ class ChatConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) subquery = ( sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 9c8b095b9f..60a2bfc799 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -100,7 +100,7 @@ class ConversationVariablesApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) def get(self, app_model): - args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) stmt = ( select(ConversationVariable) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index cbcf513162..9227d00a21 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,18 +1,18 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class TraceProviderQuery(BaseModel): tracing_provider: str = Field(..., description="Tracing provider name") @@ -23,13 +23,7 @@ class TraceConfigPayload(BaseModel): tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data") -console_ns.schema_model( - TraceProviderQuery.__name__, - TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, TraceProviderQuery, TraceConfigPayload) @console_ns.route("/apps//trace-config") @@ -49,11 +43,11 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + def get(self, app_id: UUID): + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not trace_config: return {"has_not_configured": True} return trace_config @@ -71,13 +65,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + def post(self, app_id: UUID): """Create a new trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -96,13 +90,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, app_id): + def patch(self, app_id: UUID): """Update an existing trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -119,12 +113,12 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, app_id): + def delete(self, app_id: UUID): """Delete an existing trace app configuration""" - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index ffa28b1c95..d23b2837c9 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,6 +5,7 @@ from flask import abort, jsonify, request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -15,8 +16,6 @@ from libs.helper import convert_datetime_to_date from libs.login import current_account_with_tenant, login_required from models import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class StatisticTimeRangeQuery(BaseModel): start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)") @@ -30,10 +29,7 @@ class StatisticTimeRangeQuery(BaseModel): return value -console_ns.schema_model( - StatisticTimeRangeQuery.__name__, - StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, StatisticTimeRangeQuery) @console_ns.route("/apps//statistics/daily-messages") @@ -54,7 +50,7 @@ class DailyMessageStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -111,7 +107,7 @@ class DailyConversationStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -167,7 +163,7 @@ class DailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -224,7 +220,7 @@ class DailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -284,7 +280,7 @@ class AverageSessionInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("c.created_at") sql_query = f"""SELECT @@ -360,7 +356,7 @@ class UserSatisfactionRateStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("m.created_at") sql_query = f"""SELECT @@ -426,7 +422,7 @@ class AverageResponseTimeStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -482,7 +478,7 @@ class TokensPerSecondStatistic(Resource): @account_initialization_required def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 68dd8b7a8d..4f532b437c 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -11,9 +11,9 @@ from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotF import services from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload +from controllers.common.schema import register_response_schema_model, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync -from controllers.console.app.workflow_run import workflow_run_node_execution_model from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -37,6 +37,7 @@ from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from fields.workflow_run_fields import WorkflowRunNodeExecutionResponse from graphon.enums import NodeType from graphon.file import File from graphon.file import helpers as file_helpers @@ -56,9 +57,10 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) + _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000 WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50 @@ -176,25 +178,25 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel): node_ids: list[str] -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(SyncDraftWorkflowPayload) -reg(AdvancedChatWorkflowRunPayload) -reg(IterationNodeRunPayload) -reg(LoopNodeRunPayload) -reg(DraftWorkflowRunPayload) -reg(DraftWorkflowNodeRunPayload) -reg(PublishWorkflowPayload) -reg(DefaultBlockConfigQuery) -reg(ConvertToWorkflowPayload) -reg(WorkflowListQuery) -reg(WorkflowUpdatePayload) -reg(WorkflowFeaturesPayload) -reg(WorkflowOnlineUsersPayload) -reg(DraftWorkflowTriggerRunPayload) -reg(DraftWorkflowTriggerRunAllPayload) +register_schema_models( + console_ns, + SyncDraftWorkflowPayload, + AdvancedChatWorkflowRunPayload, + IterationNodeRunPayload, + LoopNodeRunPayload, + DraftWorkflowRunPayload, + DraftWorkflowNodeRunPayload, + PublishWorkflowPayload, + DefaultBlockConfigQuery, + ConvertToWorkflowPayload, + WorkflowListQuery, + WorkflowUpdatePayload, + WorkflowFeaturesPayload, + WorkflowOnlineUsersPayload, + DraftWorkflowTriggerRunPayload, + DraftWorkflowTriggerRunAllPayload, +) +register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse) # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing @@ -540,9 +542,12 @@ class HumanInputDeliveryTestPayload(BaseModel): ) -reg(HumanInputFormPreviewPayload) -reg(HumanInputFormSubmitPayload) -reg(HumanInputDeliveryTestPayload) +register_schema_models( + console_ns, + HumanInputFormPreviewPayload, + HumanInputFormSubmitPayload, + HumanInputDeliveryTestPayload, +) @console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/preview") @@ -760,14 +765,17 @@ class DraftWorkflowNodeRunApi(Resource): @console_ns.doc(description="Run draft workflow node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) - @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model) + @console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_model) @edit_permission_required def post(self, app_model: App, node_id: str): """ @@ -799,7 +807,9 @@ class DraftWorkflowNodeRunApi(Resource): files=files, ) - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/apps//workflows/publish") @@ -902,7 +912,7 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ - args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) filters = None if args.q: @@ -995,7 +1005,7 @@ class PublishedAllWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit user_id = args.user_id @@ -1143,14 +1153,17 @@ class DraftWorkflowNodeLastRunApi(Resource): @console_ns.doc("get_draft_workflow_node_last_run") @console_ns.doc(description="Get last run result for draft workflow node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model) + @console_ns.response( + 200, + "Node last run retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @console_ns.response(404, "Node last run not found") @console_ns.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_model) def get(self, app_model: App, node_id: str): srv = WorkflowService() workflow = srv.get_draft_workflow(app_model) @@ -1163,7 +1176,7 @@ class DraftWorkflowNodeLastRunApi(Resource): ) if node_exec is None: raise NotFound("last run not found") - return node_exec + return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflows/draft/trigger/run") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 4b39590235..ddc900eb2d 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -185,7 +185,7 @@ class WorkflowAppLogApi(Resource): """ Get workflow app logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # get paginate workflow app logs workflow_app_service = WorkflowAppService() @@ -228,7 +228,7 @@ class WorkflowArchivedLogApi(Resource): """ Get workflow archived logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) workflow_app_service = WorkflowAppService() with sessionmaker(db.engine, expire_on_commit=False).begin() as session: diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index e7c3e982a6..c003be1303 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -23,7 +23,6 @@ from services.account_service import TenantService from services.workflow_comment_service import WorkflowCommentService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowCommentCreatePayload(BaseModel): @@ -52,13 +51,14 @@ class WorkflowCommentMentionUsersPayload(BaseModel): users: list[AccountWithRole] -for model in ( +register_schema_models( + console_ns, + AccountWithRole, + WorkflowCommentMentionUsersPayload, WorkflowCommentCreatePayload, WorkflowCommentUpdatePayload, WorkflowCommentReplyPayload, -): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) +) workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index c688a69074..3c887c33dc 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -8,6 +8,7 @@ from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, @@ -33,7 +34,6 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) _file_access_controller = DatabaseFileAccessController() -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowDraftVariableListQuery(BaseModel): @@ -56,21 +56,12 @@ class EnvironmentVariableUpdatePayload(BaseModel): environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow") -console_ns.schema_model( - WorkflowDraftVariableListQuery.__name__, - WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - WorkflowDraftVariableUpdatePayload.__name__, - WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ConversationVariableUpdatePayload.__name__, - ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - EnvironmentVariableUpdatePayload.__name__, - EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +register_schema_models( + console_ns, + WorkflowDraftVariableListQuery, + WorkflowDraftVariableUpdatePayload, + ConversationVariableUpdatePayload, + EnvironmentVariableUpdatePayload, ) @@ -260,7 +251,7 @@ class WorkflowVariableCollectionApi(Resource): """ Get draft workflow """ - args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # fetch draft workflow by app_model workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 6748d95d6b..97d2003209 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,30 +1,28 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, TypedDict, cast +from typing import Literal, cast from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker from configs import dify_config +from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from fields.base import ResponseModel from fields.workflow_run_fields import ( - advanced_chat_workflow_run_for_list_fields, - advanced_chat_workflow_run_pagination_fields, - workflow_run_count_fields, - workflow_run_detail_fields, - workflow_run_for_list_fields, - workflow_run_node_execution_fields, - workflow_run_node_execution_list_fields, - workflow_run_pagination_fields, + AdvancedChatWorkflowRunPaginationResponse, + WorkflowRunCountResponse, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, ) from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus @@ -52,82 +50,6 @@ def _build_backstage_input_url(form_token: str | None) -> str | None: WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model("SimpleAccount", simple_account_fields) - -simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields) - -# Models that depend on simple_account_fields -workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy() -workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy) - -advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy() -advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -advanced_chat_workflow_run_for_list_model = console_ns.model( - "AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy -) - -workflow_run_detail_fields_copy = workflow_run_detail_fields.copy() -workflow_run_detail_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True -) -workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy) - -workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy() -workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True -) -workflow_run_node_execution_model = console_ns.model( - "WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy -) - -# Simple models without nested dependencies -workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields) - -# Pagination models that depend on list models -advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy() -advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List( - fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data" -) -advanced_chat_workflow_run_pagination_model = console_ns.model( - "AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy -) - -workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy() -workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data") -workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy) - -workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy() -workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model)) -workflow_run_node_execution_list_model = console_ns.model( - "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy -) - -workflow_run_export_fields = console_ns.model( - "WorkflowRunExport", - { - "status": fields.String(description="Export status: success/failed"), - "presigned_url": fields.String(description="Pre-signed URL for download", required=False), - "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False), - }, -) - -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowRunListQuery(BaseModel): last_id: str | None = Field(default=None, description="Last run ID for pagination") @@ -136,7 +58,7 @@ class WorkflowRunListQuery(BaseModel): default=None, description="Workflow run status filter" ) triggered_from: Literal["debugging", "app-run"] | None = Field( - default=None, description="Filter by trigger source: debugging or app-run" + default=None, description="Filter by trigger source: debugging or app-run. Default: debugging" ) @field_validator("last_id") @@ -151,9 +73,15 @@ class WorkflowRunCountQuery(BaseModel): status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field( default=None, description="Workflow run status filter" ) - time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)") + time_range: str | None = Field( + default=None, + description=( + "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " + "30m (30 minutes), 30s (30 seconds). Filters by created_at field." + ), + ) triggered_from: Literal["debugging", "app-run"] | None = Field( - default=None, description="Filter by trigger source: debugging or app-run" + default=None, description="Filter by trigger source: debugging or app-run. Default: debugging" ) @field_validator("time_range") @@ -164,56 +92,69 @@ class WorkflowRunCountQuery(BaseModel): return time_duration(value) -console_ns.schema_model( - WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - WorkflowRunCountQuery.__name__, - WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class WorkflowRunExportResponse(ResponseModel): + status: str = Field(description="Export status: success/failed") + presigned_url: str | None = Field(default=None, description="Pre-signed URL for download") + presigned_url_expires_at: str | None = Field(default=None, description="Pre-signed URL expiration time") -class HumanInputPauseTypeResponse(TypedDict): +class HumanInputPauseTypeResponse(ResponseModel): type: Literal["human_input"] form_id: str - backstage_input_url: str | None + backstage_input_url: str | None = None -class PausedNodeResponse(TypedDict): +class PausedNodeResponse(ResponseModel): node_id: str node_title: str pause_type: HumanInputPauseTypeResponse -class WorkflowPauseDetailsResponse(TypedDict): - paused_at: str | None +class WorkflowPauseDetailsResponse(ResponseModel): + paused_at: str | None = None paused_nodes: list[PausedNodeResponse] +register_schema_models( + console_ns, + WorkflowRunListQuery, + WorkflowRunCountQuery, +) +register_response_schema_models( + console_ns, + AdvancedChatWorkflowRunPaginationResponse, + WorkflowRunPaginationResponse, + WorkflowRunCountResponse, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunExportResponse, + HumanInputPauseTypeResponse, + PausedNodeResponse, + WorkflowPauseDetailsResponse, +) + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @console_ns.doc(description="Get advanced chat workflow run list") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[AdvancedChatWorkflowRunPaginationResponse.__name__], ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) - @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @marshal_with(advanced_chat_workflow_run_pagination_model) def get(self, app_model: App): """ Get advanced chat app workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -232,7 +173,9 @@ class AdvancedChatAppWorkflowRunListApi(Resource): app_model=app_model, args=args, triggered_from=triggered_from ) - return result + return AdvancedChatWorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//workflow-runs//export") @@ -240,7 +183,7 @@ class WorkflowRunExportApi(Resource): @console_ns.doc("get_workflow_run_export_url") @console_ns.doc(description="Generate a download URL for an archived workflow run.") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Export URL generated", workflow_run_export_fields) + @console_ns.response(200, "Export URL generated", console_ns.models[WorkflowRunExportResponse.__name__]) @setup_required @login_required @account_initialization_required @@ -278,11 +221,14 @@ class WorkflowRunExportApi(Resource): expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS, ) expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS) - return { - "status": "success", - "presigned_url": presigned_url, - "presigned_url_expires_at": expires_at.isoformat(), - }, 200 + response = WorkflowRunExportResponse.model_validate( + { + "status": "success", + "presigned_url": presigned_url, + "presigned_url_expires_at": expires_at.isoformat(), + } + ) + return response.model_dump(mode="json"), 200 @console_ns.route("/apps//advanced-chat/workflow-runs/count") @@ -290,32 +236,21 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs_count") @console_ns.doc(description="Get advanced chat workflow runs count statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery)) + @console_ns.response( + 200, + "Workflow runs count retrieved successfully", + console_ns.models[WorkflowRunCountResponse.__name__], ) - @console_ns.doc( - params={ - "time_range": ( - "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " - "30m (30 minutes), 30s (30 seconds). Filters by created_at field." - ) - } - ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) - @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @marshal_with(workflow_run_count_model) def get(self, app_model: App): """ Get advanced chat workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING if not specified @@ -333,7 +268,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): triggered_from=triggered_from, ) - return result + return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json") @console_ns.route("/apps//workflow-runs") @@ -341,25 +276,21 @@ class WorkflowRunListApi(Resource): @console_ns.doc("get_workflow_runs") @console_ns.doc(description="Get workflow run list") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[WorkflowRunPaginationResponse.__name__], ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) - @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_pagination_model) def get(self, app_model: App): """ Get workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -378,7 +309,7 @@ class WorkflowRunListApi(Resource): app_model=app_model, args=args, triggered_from=triggered_from ) - return result + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflow-runs/count") @@ -386,32 +317,21 @@ class WorkflowRunCountApi(Resource): @console_ns.doc("get_workflow_runs_count") @console_ns.doc(description="Get workflow runs count statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery)) + @console_ns.response( + 200, + "Workflow runs count retrieved successfully", + console_ns.models[WorkflowRunCountResponse.__name__], ) - @console_ns.doc( - params={ - "time_range": ( - "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " - "30m (30 minutes), 30s (30 seconds). Filters by created_at field." - ) - } - ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) - @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_count_model) def get(self, app_model: App): """ Get workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING for workflow if not specified (backward compatibility) @@ -429,7 +349,7 @@ class WorkflowRunCountApi(Resource): triggered_from=triggered_from, ) - return result + return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json") @console_ns.route("/apps//workflow-runs/") @@ -437,13 +357,16 @@ class WorkflowRunDetailApi(Resource): @console_ns.doc("get_workflow_run_detail") @console_ns.doc(description="Get workflow run detail") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model) + @console_ns.response( + 200, + "Workflow run detail retrieved successfully", + console_ns.models[WorkflowRunDetailResponse.__name__], + ) @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_detail_model) def get(self, app_model: App, run_id): """ Get workflow run detail @@ -452,8 +375,10 @@ class WorkflowRunDetailApi(Resource): workflow_run_service = WorkflowRunService() workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + if workflow_run is None: + raise NotFoundError("Workflow run not found") - return workflow_run + return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflow-runs//node-executions") @@ -461,13 +386,16 @@ class WorkflowRunNodeExecutionListApi(Resource): @console_ns.doc("get_workflow_run_node_executions") @console_ns.doc(description="Get workflow run node execution list") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model) + @console_ns.response( + 200, + "Node executions retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionListResponse.__name__], + ) @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_list_model) def get(self, app_model: App, run_id): """ Get workflow run node execution list @@ -482,13 +410,24 @@ class WorkflowRunNodeExecutionListApi(Resource): user=user, ) - return {"data": node_executions} + return WorkflowRunNodeExecutionListResponse.model_validate( + {"data": node_executions}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/workflow//pause-details") class ConsoleWorkflowPauseDetailsApi(Resource): """Console API for getting workflow pause details.""" + @console_ns.doc("get_workflow_pause_details") + @console_ns.doc(description="Get workflow pause details") + @console_ns.doc(params={"workflow_run_id": "Workflow run ID"}) + @console_ns.response( + 200, + "Workflow pause details retrieved successfully", + console_ns.models[WorkflowPauseDetailsResponse.__name__], + ) + @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -515,11 +454,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - empty_response: WorkflowPauseDetailsResponse = { - "paused_at": None, - "paused_nodes": [], - } - return empty_response, 200 + empty_response = WorkflowPauseDetailsResponse(paused_at=None, paused_nodes=[]) + return empty_response.model_dump(mode="json"), 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] @@ -530,27 +466,25 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Build response paused_at = pause_entity.paused_at if pause_entity else None paused_nodes: list[PausedNodeResponse] = [] - response: WorkflowPauseDetailsResponse = { - "paused_at": paused_at.isoformat() + "Z" if paused_at else None, - "paused_nodes": paused_nodes, - } for reason in pause_reasons: if isinstance(reason, HumanInputRequired): paused_nodes.append( - { - "node_id": reason.node_id, - "node_title": reason.node_title, - "pause_type": { - "type": "human_input", - "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url( - form_tokens_by_form_id.get(reason.form_id) - ), - }, - } + PausedNodeResponse( + node_id=reason.node_id, + node_title=reason.node_title, + pause_type=HumanInputPauseTypeResponse( + type="human_input", + form_id=reason.form_id, + backstage_input_url=_build_backstage_input_url(form_tokens_by_form_id.get(reason.form_id)), + ), + ) ) else: raise AssertionError("unimplemented.") - return response, 200 + response = WorkflowPauseDetailsResponse( + paused_at=paused_at.isoformat() + "Z" if paused_at else None, + paused_nodes=paused_nodes, + ) + return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index e48cf42762..ca899d8784 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,6 +3,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -13,8 +14,6 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowStatisticQuery(BaseModel): start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)") @@ -28,10 +27,7 @@ class WorkflowStatisticQuery(BaseModel): return value -console_ns.schema_model( - WorkflowStatisticQuery.__name__, - WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, WorkflowStatisticQuery) @console_ns.route("/apps//workflow/statistics/daily-conversations") @@ -53,7 +49,7 @@ class WorkflowDailyRunsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -93,7 +89,7 @@ class WorkflowDailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -133,7 +129,7 @@ class WorkflowDailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -173,7 +169,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index a6715fa200..a80b4f5d0c 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -94,7 +94,7 @@ class WebhookTriggerApi(Resource): @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__]) def get(self, app_model: App): """Get webhook trigger for a node""" - args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = Parser.model_validate(request.args.to_dict(flat=True)) node_id = args.node_id diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f7061f820f..0c05cf2fe3 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -63,7 +63,7 @@ class ActivateCheckApi(Resource): console_ns.models[ActivationCheckResponse.__name__], ) def get(self): - args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) workspaceId = args.workspace_id token = args.token diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 905d0daef0..db0d36af6e 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,6 +1,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -8,8 +9,6 @@ from .. import console_ns from ..auth.error import ApiKeyAuthFailedError from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ApiKeyAuthBindingPayload(BaseModel): category: str = Field(...) @@ -17,10 +16,7 @@ class ApiKeyAuthBindingPayload(BaseModel): credentials: dict = Field(...) -console_ns.schema_model( - ApiKeyAuthBindingPayload.__name__, - ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, ApiKeyAuthBindingPayload) @console_ns.route("/api-key-auth/data-source") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 1fd781b4fc..f6b8aedf22 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator from configs import dify_config from constants.languages import languages +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -23,8 +24,6 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError from ..error import AccountInFreezeError, EmailSendIpLimitError from ..wraps import email_password_login_enabled, email_register_enabled, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EmailRegisterSendPayload(BaseModel): email: EmailStr = Field(..., description="Email address") @@ -48,8 +47,7 @@ class EmailRegisterResetPayload(BaseModel): return valid_password(value) -for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload) @console_ns.route("/email-register/send-email") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ed390a5f89..c34dd1ac85 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -28,8 +28,6 @@ from services.entities.auth_entities import ( ) from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ForgotPasswordEmailResponse(BaseModel): result: str = Field(description="Operation result") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 8216b3d0da..19c98f3a1a 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized import services from configs import dify_config from constants.languages import get_valid_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, @@ -50,7 +51,6 @@ from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" logger = logging.getLogger(__name__) @@ -71,13 +71,7 @@ class EmailCodeLoginPayload(BaseModel): language: str | None = Field(default=None) -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(LoginPayload) -reg(EmailPayload) -reg(EmailCodeLoginPayload) +register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload) @console_ns.route("/login") diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 7caf5b52ed..a43caa8f56 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -4,6 +4,7 @@ from flask_restx import ( # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required @@ -12,8 +13,6 @@ from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class Parser(BaseModel): inputs: dict @@ -21,7 +20,7 @@ class Parser(BaseModel): credential_id: str | None = None -console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, Parser) @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index ee146e8287..8eff32c555 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotF import services from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( ConversationCompletedError, @@ -22,12 +22,6 @@ from controllers.console.app.workflow import ( workflow_model, workflow_pagination_model, ) -from controllers.console.app.workflow_run import ( - workflow_run_detail_model, - workflow_run_node_execution_list_model, - workflow_run_node_execution_model, - workflow_run_pagination_model, -) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, @@ -40,6 +34,12 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory +from fields.workflow_run_fields import ( + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, +) from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty @@ -131,6 +131,13 @@ register_schema_models( DatasourceVariablesPayload, RagPipelineRecommendedPluginQuery, ) +register_response_schema_models( + console_ns, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, +) @console_ns.route("/rag/pipelines//workflows/draft") @@ -415,12 +422,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/nodes//run") class RagPipelineDraftNodeRunApi(Resource): @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__]) + @console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @edit_permission_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow node @@ -439,7 +450,9 @@ class RagPipelineDraftNodeRunApi(Resource): if workflow_node_execution is None: raise ValueError("Workflow node execution not found") - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs/tasks//stop") @@ -778,11 +791,15 @@ class DraftRagPipelineSecondStepApi(Resource): @console_ns.route("/rag/pipelines//workflow-runs") class RagPipelineWorkflowRunListApi(Resource): + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[WorkflowRunPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_pagination_model) def get(self, pipeline: Pipeline): """ Get workflow run list @@ -801,16 +818,20 @@ class RagPipelineWorkflowRunListApi(Resource): rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) - return result + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs/") class RagPipelineWorkflowRunDetailApi(Resource): + @console_ns.response( + 200, + "Workflow run detail retrieved successfully", + console_ns.models[WorkflowRunDetailResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_detail_model) def get(self, pipeline: Pipeline, run_id): """ Get workflow run detail @@ -819,17 +840,23 @@ class RagPipelineWorkflowRunDetailApi(Resource): rag_pipeline_service = RagPipelineService() workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + if workflow_run is None: + raise NotFound("Workflow run not found") - return workflow_run + return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs//node-executions") class RagPipelineWorkflowRunNodeExecutionListApi(Resource): + @console_ns.response( + 200, + "Node executions retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionListResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_list_model) def get(self, pipeline: Pipeline, run_id: str): """ Get workflow run node execution list @@ -844,7 +871,9 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): user=user, ) - return {"data": node_executions} + return WorkflowRunNodeExecutionListResponse.model_validate( + {"data": node_executions}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines/datasource-plugins") @@ -859,11 +888,15 @@ class DatasourceListApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/nodes//last-run") class RagPipelineWorkflowLastRunApi(Resource): + @console_ns.response( + 200, + "Node last run retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_model) def get(self, pipeline: Pipeline, node_id: str): rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -876,7 +909,7 @@ class RagPipelineWorkflowLastRunApi(Resource): ) if node_exec is None: raise NotFound("last run not found") - return node_exec + return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines/transform/datasets/") @@ -899,12 +932,16 @@ class RagPipelineTransformApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/datasource/variables-inspect") class RagPipelineDatasourceVariableApi(Resource): @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__]) + @console_ns.response( + 200, + "Datasource variables set successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline): """ Set datasource variables @@ -918,7 +955,9 @@ class RagPipelineDatasourceVariableApi(Resource): args=args, current_user=current_user, ) - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines/recommended-plugins") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 572f9773a1..5821b91489 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource @@ -80,7 +81,7 @@ class RecommendedAppListApi(Resource): @account_initialization_required def get(self): # language args - args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) language = args.language if language and language in languages: language_prefix = language @@ -99,6 +100,5 @@ class RecommendedAppListApi(Resource): class RecommendedAppApi(Resource): @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - return RecommendedAppService.get_recommend_app_detail(app_id) + def get(self, app_id: UUID): + return RecommendedAppService.get_recommend_app_detail(str(app_id)) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1456301a24..025c517d20 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse -from controllers.common.schema import get_or_create_model +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -120,10 +120,6 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) -# Pydantic models for request validation -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class WorkflowRunRequest(BaseModel): inputs: dict files: list | None = None @@ -153,19 +149,7 @@ class CompletionRequest(BaseModel): retriever_from: str = "explore_app" -# Register schemas for Swagger documentation -console_ns.schema_model( - WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeechRequest, CompletionRequest) class TrialAppWorkflowRunApi(TrialAppResource): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 7a6356d052..9ffc18e4c2 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -89,7 +89,7 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) return CodeBasedExtensionResponse( module=query.module, diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 109a3cd0d3..9fa5b0f5c1 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -82,7 +82,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source=source, diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index d69a59ecb7..68520e540b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -52,8 +52,6 @@ from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AccountInitPayload(BaseModel): interface_language: str @@ -161,27 +159,26 @@ class CheckEmailUniquePayload(BaseModel): email: EmailStr -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AccountInitPayload) -reg(AccountNamePayload) -reg(AccountAvatarPayload) -reg(AccountAvatarQuery) -reg(AccountInterfaceLanguagePayload) -reg(AccountInterfaceThemePayload) -reg(AccountTimezonePayload) -reg(AccountPasswordPayload) -reg(AccountDeletePayload) -reg(AccountDeletionFeedbackPayload) -reg(EducationActivatePayload) -reg(EducationAutocompleteQuery) -reg(ChangeEmailSendPayload) -reg(ChangeEmailValidityPayload) -reg(ChangeEmailResetPayload) -reg(CheckEmailUniquePayload) -register_schema_models(console_ns, AccountResponse) +register_schema_models( + console_ns, + AccountResponse, + AccountInitPayload, + AccountNamePayload, + AccountAvatarPayload, + AccountAvatarQuery, + AccountInterfaceLanguagePayload, + AccountInterfaceThemePayload, + AccountTimezonePayload, + AccountPasswordPayload, + AccountDeletePayload, + AccountDeletionFeedbackPayload, + EducationActivatePayload, + EducationAutocompleteQuery, + ChangeEmailSendPayload, + ChangeEmailValidityPayload, + ChangeEmailResetPayload, + CheckEmailUniquePayload, +) def _serialize_account(account) -> dict[str, Any]: @@ -326,7 +323,7 @@ class AccountAvatarApi(Resource): @account_initialization_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) avatar = args.avatar if avatar.startswith(("http://", "https://")): diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index d4be07382a..925f3e1197 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -20,8 +20,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EndpointCreatePayload(BaseModel): plugin_unique_identifier: str @@ -80,10 +78,6 @@ class EndpointDisableResponse(BaseModel): success: bool = Field(description="Operation success") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - register_schema_models( console_ns, EndpointCreatePayload, @@ -215,7 +209,7 @@ class EndpointListApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size @@ -248,7 +242,7 @@ class EndpointListForSinglePluginApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index e3bf4c95b8..c2533c9872 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -33,8 +33,6 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class MemberInvitePayload(BaseModel): emails: list[str] = Field(default_factory=list) @@ -59,17 +57,17 @@ class OwnerTransferPayload(BaseModel): token: str -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(MemberInvitePayload) -reg(MemberRoleUpdatePayload) -reg(OwnerTransferEmailPayload) -reg(OwnerTransferCheckPayload) -reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) -register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) +register_schema_models( + console_ns, + AccountWithRole, + AccountWithRoleList, + MemberInvitePayload, + MemberRoleUpdatePayload, + OwnerTransferEmailPayload, + OwnerTransferCheckPayload, + OwnerTransferPayload, +) @console_ns.route("/workspaces/current/members") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 4b10561fdb..2f75218c0f 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -5,6 +5,7 @@ from flask import request, send_file from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from graphon.model_runtime.entities.model_entities import ModelType @@ -15,8 +16,6 @@ from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ParserModelList(BaseModel): model_type: ModelType | None = None @@ -75,18 +74,17 @@ class ParserPreferredProviderType(BaseModel): preferred_provider_type: Literal["system", "custom"] -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(ParserModelList) -reg(ParserCredentialId) -reg(ParserCredentialCreate) -reg(ParserCredentialUpdate) -reg(ParserCredentialDelete) -reg(ParserCredentialSwitch) -reg(ParserCredentialValidate) -reg(ParserPreferredProviderType) +register_schema_models( + console_ns, + ParserModelList, + ParserCredentialId, + ParserCredentialCreate, + ParserCredentialUpdate, + ParserCredentialDelete, + ParserCredentialSwitch, + ParserCredentialValidate, + ParserPreferredProviderType, +) @console_ns.route("/workspaces/current/model-providers") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b2d07ff8f9..7f7d6379c3 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -17,7 +17,6 @@ from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ParserGetDefault(BaseModel): @@ -107,6 +106,12 @@ class ParserParameter(BaseModel): model: str +class ParserSwitch(BaseModel): + model: str + model_type: ModelType + credential_id: str + + register_schema_models( console_ns, ParserGetDefault, @@ -119,6 +124,7 @@ register_schema_models( ParserDeleteCredential, ParserParameter, Inner, + ParserSwitch, ) register_enum_models(console_ns, ModelType) @@ -133,7 +139,7 @@ class DefaultModelApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( @@ -261,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): def get(self, provider: str): _, tenant_id = current_account_with_tenant() - args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() current_credential = model_provider_service.get_model_credential( @@ -387,17 +393,6 @@ class ModelProviderModelCredentialApi(Resource): return {"result": "success"}, 204 -class ParserSwitch(BaseModel): - model: str - model_type: ModelType - credential_id: str - - -console_ns.schema_model( - ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - @console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): @console_ns.expect(console_ns.models[ParserSwitch.__name__]) @@ -468,9 +463,7 @@ class ParserValidate(BaseModel): credentials: dict[str, Any] -console_ns.schema_model( - ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, ParserSwitch, ParserValidate) @console_ns.route("/workspaces/current/model-providers//models/credentials/validate") @@ -515,7 +508,7 @@ class ModelProviderModelParameterRuleApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserParameter.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea..a6d4a60beb 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -177,7 +177,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes: FileStorage.content_length is not reliable for multipart test uploads and may be zero even when content exists, so the controllers validate against the loaded bytes instead. """ - content = file.read() + content = file.stream.read() if len(content) > max_size: raise ValueError("File size exceeds the maximum allowed size") @@ -211,7 +211,7 @@ class PluginListApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserList.model_validate(request.args.to_dict(flat=True)) try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: @@ -261,7 +261,7 @@ class PluginIconApi(Resource): @console_ns.expect(console_ns.models[ParserIcon.__name__]) @setup_required def get(self): - args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserIcon.model_validate(request.args.to_dict(flat=True)) try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) @@ -279,7 +279,7 @@ class PluginAssetApi(Resource): @login_required @account_initialization_required def get(self): - args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserAsset.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() try: @@ -421,7 +421,7 @@ class PluginFetchMarketplacePkgApi(Resource): @plugin_permission_required(install_required=True) def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -446,7 +446,7 @@ class PluginFetchManifestApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -466,7 +466,7 @@ class PluginFetchInstallTasksApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserTasks.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) @@ -660,7 +660,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): current_user, tenant_id = current_account_with_tenant() user_id = current_user.id - args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) try: options = PluginParameterService.get_dynamic_select_options( @@ -822,7 +822,7 @@ class PluginReadmeApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserReadme.model_validate(request.args.to_dict(flat=True)) return jsonable_encoder( {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 565099db61..84890f0443 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -16,6 +16,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError @@ -39,7 +40,6 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkspaceListQuery(BaseModel): @@ -91,15 +91,14 @@ class TenantInfoResponse(ResponseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(WorkspaceListQuery) -reg(SwitchWorkspacePayload) -reg(WorkspaceCustomConfigPayload) -reg(WorkspaceInfoPayload) -reg(TenantInfoResponse) +register_schema_models( + console_ns, + WorkspaceListQuery, + SwitchWorkspacePayload, + WorkspaceCustomConfigPayload, + WorkspaceInfoPayload, + TenantInfoResponse, +) provider_fields = { "provider_name": fields.String, @@ -322,7 +321,7 @@ class WebappLogoWorkspaceApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index a91e745f80..be7886e831 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -8,13 +8,12 @@ from werkzeug.exceptions import NotFound import services from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from extensions.ext_database import db from services.account_service import TenantService from services.file_service import FileService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class FileSignatureQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp used in the signature") @@ -26,12 +25,7 @@ class FilePreviewQuery(FileSignatureQuery): as_attachment: bool = Field(default=False, description="Whether to download as attachment") -files_ns.schema_model( - FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -files_ns.schema_model( - FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, FileSignatureQuery, FilePreviewQuery) @files_ns.route("//image-preview") @@ -58,7 +52,7 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) timestamp = args.timestamp nonce = args.nonce sign = args.sign @@ -100,7 +94,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 2f1e2f28bd..8ae16ce7f4 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -7,12 +7,11 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ToolFileQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp") @@ -21,9 +20,7 @@ class ToolFileQuery(BaseModel): as_attachment: bool = Field(default=False, description="Download as attachment") -files_ns.schema_model( - ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, ToolFileQuery) @files_ns.route("/tools/.") diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index ed3278a28b..7d588b95dd 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -20,8 +20,6 @@ from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class PluginUploadQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp for signature verification") @@ -31,9 +29,8 @@ class PluginUploadQuery(BaseModel): user_id: str | None = Field(default=None, description="User identifier") -files_ns.schema_model( - PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, PluginUploadQuery) + register_schema_models(files_ns, FileResponse) @@ -69,7 +66,7 @@ class PluginUploadFileApi(Resource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) file = request.files.get("file") if file is None: @@ -103,7 +100,7 @@ class PluginUploadFileApi(Resource): tool_file = ToolFileManager().create_file_by_raw( user_id=user.id, tenant_id=tenant_id, - file_binary=file.read(), + file_binary=file.stream.read(), mimetype=mimetype, filename=filename, conversation_id=None, diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 6f6dadf768..687d34076d 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -58,7 +58,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, ) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 3eb773fa7c..9af66f1960 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_valid from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError @@ -34,13 +34,7 @@ from services.tag_service import ( UpdateTagPayload, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - -service_api_ns.schema_model( - DatasetPermissionEnum.__name__, - TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_enum_models(service_api_ns, DatasetPermissionEnum) class DatasetCreatePayload(BaseModel): diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0b09facf58..cb48fe6715 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -77,9 +77,6 @@ class DocumentTextCreatePayload(BaseModel): return value -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class DocumentTextUpdate(BaseModel): name: str | None = None text: str | None = None @@ -435,7 +432,7 @@ class DocumentAddByFileApi(DatasetApiResource): raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", @@ -509,7 +506,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 2dc98bfbf7..8bc43bccd5 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -241,7 +241,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 0036c90800..6128490104 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -73,7 +73,7 @@ class FileApi(WebApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, source="datasets" if source == "datasets" else None, diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c22102c2ba..cba4659483 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -532,7 +532,6 @@ class BaseAgentRunner(AppRunner): file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=self.tenant_id, - config=file_extra_config, access_controller=_file_access_controller, ) if not file_objs: diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 4c07445df3..f4bbbe5d8b 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -75,7 +75,7 @@ class PromptTemplateConfigManager: if not config.get("prompt_type"): config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE - prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + prompt_type_vals = list(PromptTemplateEntity.PromptType) if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d840ee213c..c41c175cca 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -86,12 +86,10 @@ class TokenBufferMemory: detail = ImagePromptMessageContent.DETAIL.HIGH if file_extra_config and app_record: - # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( message_file=message_file, tenant_id=app_record.tenant_id, - config=file_extra_config, access_controller=_file_access_controller, ) for message_file in message_files diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b985ebbe1d..7769878e70 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -21,7 +21,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( @@ -893,7 +893,7 @@ class RetrievalService: "name": upload_file.name, "extension": "." + upload_file.extension, "mime_type": upload_file.mime_type, - "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "source_url": sign_upload_file_preview_url(upload_file.id, upload_file.extension), "size": upload_file.size, } return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} @@ -920,7 +920,7 @@ class RetrievalService: "name": upload_file.name, "extension": "." + upload_file.extension, "mime_type": upload_file.mime_type, - "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "source_url": sign_upload_file_preview_url(upload_file.id, upload_file.extension), "size": upload_file.size, } if attachment_binding: diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 02f0efc908..25f6fe3e2a 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -115,7 +115,7 @@ class PdfExtractor(BaseExtractor): """ image_content = [] upload_files = [] - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL try: image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0330a43b28..60f8906181 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -110,7 +110,7 @@ class WordExtractor(BaseExtractor): def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 5631b3a921..010566d203 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -52,7 +52,7 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc @@ -529,7 +529,7 @@ class DatasetRetrieval: ), size=upload_file.size, storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), + url=sign_upload_file_preview_url(upload_file.id, upload_file.extension), ) context_files.append(attachment_info) if show_retrieve_source: diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 1807226924..3c7b523ff1 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -26,12 +26,14 @@ def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" -def sign_upload_file(upload_file_id: str, extension: str) -> str: +def sign_upload_file_preview_url(upload_file_id: str, extension: str) -> str: """ - sign file to get a temporary url for plugin access + Sign an upload file to get a temporary image preview URL. + + The URL generated by this function is only for external preview and download, + not for internal communication. """ - # Use internal URL for plugin/tool file access in Docker environments - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) diff --git a/api/dev/generate_swagger_markdown_docs.py b/api/dev/generate_swagger_markdown_docs.py index 0900d08331..e0028c63f6 100644 --- a/api/dev/generate_swagger_markdown_docs.py +++ b/api/dev/generate_swagger_markdown_docs.py @@ -29,18 +29,39 @@ STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md" def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None: - subprocess.run( - [ - "npx", - "--yes", - SWAGGER_MARKDOWN_PACKAGE, - "-i", - str(spec_path), - "-o", - str(markdown_path), - ], - check=True, - ) + markdown_path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(prefix=f"{markdown_path.stem}-", dir=markdown_path.parent) as temp_dir: + temp_markdown_path = Path(temp_dir) / markdown_path.name + result = subprocess.run( + [ + "npx", + "--yes", + SWAGGER_MARKDOWN_PACKAGE, + "-i", + str(spec_path), + "-o", + str(temp_markdown_path), + ], + check=False, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise subprocess.CalledProcessError( + result.returncode, + result.args, + output=result.stdout, + stderr=result.stderr, + ) + if not temp_markdown_path.exists(): + converter_output = "\n".join(item for item in (result.stdout, result.stderr) if item).strip() + raise RuntimeError(f"swagger-markdown did not write {markdown_path}: {converter_output}") + + converted_markdown = temp_markdown_path.read_text(encoding="utf-8") + if not converted_markdown.strip(): + raise RuntimeError(f"swagger-markdown wrote an empty document for {markdown_path}") + + markdown_path.write_text(converted_markdown, encoding="utf-8") def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str: diff --git a/api/dev/generate_swagger_specs.py b/api/dev/generate_swagger_specs.py index 9122f3ab24..254310cd2a 100644 --- a/api/dev/generate_swagger_specs.py +++ b/api/dev/generate_swagger_specs.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import Protocol, TypeGuard from flask import Flask -from flask_restx.swagger import Swagger logger = logging.getLogger(__name__) @@ -48,9 +47,6 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = ( SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"), ) -_ORIGINAL_REGISTER_MODEL = Swagger.register_model -_ORIGINAL_REGISTER_FIELD = Swagger.register_field - def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]: """Return whether a nested field map is an anonymous inline mapping.""" @@ -152,56 +148,14 @@ def apply_runtime_defaults() -> None: dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true" -def _patch_swagger_for_inline_nested_dicts() -> None: - """Teach Flask-RESTX Swagger generation to tolerate inline nested field maps. - - Some existing controllers use `fields.Nested({...})` with a raw field mapping - instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous - dicts during schema registration, so this helper upgrades them into temporary - named models at export time. - """ - - if getattr(Swagger, "_dify_inline_nested_dict_patch", False): - return - - def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: - anonymous_models = getattr(self, "_anonymous_inline_models", None) - if anonymous_models is None: - anonymous_models = {} - self.__dict__["_anonymous_inline_models"] = anonymous_models - - anonymous_name = anonymous_models.get(id(nested_fields)) - if anonymous_name is None: - anonymous_name = _inline_model_name(nested_fields) - anonymous_models[id(nested_fields)] = anonymous_name - if anonymous_name not in self.api.models: - self.api.model(anonymous_name, nested_fields) - - return self.api.models[anonymous_name] - - def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: - if _is_inline_field_map(model): - model = get_or_create_inline_model(self, model) - - return _ORIGINAL_REGISTER_MODEL(self, model) - - def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: - nested = getattr(field, "nested", None) - if _is_inline_field_map(nested): - field.model = get_or_create_inline_model(self, nested) # type: ignore - - _ORIGINAL_REGISTER_FIELD(self, field) - - Swagger.register_model = register_model_with_inline_dict_support - Swagger.register_field = register_field_with_inline_dict_support - Swagger._dify_inline_nested_dict_patch = True - - def create_spec_app() -> Flask: """Build a minimal Flask app that only mounts the Swagger-producing blueprints.""" apply_runtime_defaults() - _patch_swagger_for_inline_nested_dicts() + + from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts + + patch_swagger_for_inline_nested_dicts() app = Flask(__name__) diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 4b3d514238..27441bdcc1 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -1,11 +1,18 @@ -"""Adapters from persisted message files to graph-layer file values.""" +"""Adapters from persisted message files to graph-layer file values. + +Replay paths only: files in conversation history were validated at upload time, +so these helpers deliberately do not accept (or forward) a ``FileUploadConfig`` — +re-validation here would break replays whenever workflow ``file_upload`` config +drifts between rounds. Mirrors ``build_file_from_stored_mapping`` in +``models/utils/file_input_compat.py``. +""" from __future__ import annotations from collections.abc import Sequence from core.app.file_access import FileAccessControllerProtocol -from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig +from graphon.file import File, FileBelongsTo, FileTransferMethod from models import MessageFile from .builders import build_from_mapping @@ -15,14 +22,12 @@ def build_from_message_files( *, message_files: Sequence[MessageFile], tenant_id: str, - config: FileUploadConfig | None = None, access_controller: FileAccessControllerProtocol, ) -> Sequence[File]: return [ build_from_message_file( message_file=message_file, tenant_id=tenant_id, - config=config, access_controller=access_controller, ) for message_file in message_files @@ -34,7 +39,6 @@ def build_from_message_file( *, message_file: MessageFile, tenant_id: str, - config: FileUploadConfig | None, access_controller: FileAccessControllerProtocol, ) -> File: mapping = { @@ -54,6 +58,5 @@ def build_from_message_file( return build_from_mapping( mapping=mapping, tenant_id=tenant_id, - config=config, access_controller=access_controller, ) diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py index 4c4f6150e4..8c4e7ef1d4 100644 --- a/api/factories/file_factory/validation.py +++ b/api/factories/file_factory/validation.py @@ -2,9 +2,25 @@ from __future__ import annotations +from collections.abc import Iterable + from graphon.file import FileTransferMethod, FileType, FileUploadConfig +def _normalize_extension(extension: str) -> str: + s = extension.strip().lower() + if not s: + return "" + return s if s.startswith(".") else "." + s + + +def _extension_matches(extension: str, whitelist: Iterable[str]) -> bool: + normalized = _normalize_extension(extension) + if not normalized: + return False + return normalized in {_normalize_extension(e) for e in whitelist} + + def is_file_valid_with_config( *, input_file_type: str, @@ -12,22 +28,31 @@ def is_file_valid_with_config( file_transfer_method: FileTransferMethod, config: FileUploadConfig, ) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions + """Return whether the file is allowed by the upload config. + + ``allowed_file_types`` lists the buckets a file may fall into; ``CUSTOM`` is + a fallback bucket gated by ``allowed_file_extensions`` (case- and + dot-insensitive). Tool-generated files bypass user-facing config. + """ if file_transfer_method == FileTransferMethod.TOOL_FILE: return True - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): + allowed_types = config.allowed_file_types or [] + custom_allowed = FileType.CUSTOM in allowed_types + type_allowed = not allowed_types or input_file_type in allowed_types + + if not type_allowed and not custom_allowed: return False + # When the file is in the CUSTOM bucket, the extension whitelist is authoritative. + # An explicitly set whitelist (including the empty list) is enforced; empty == deny — + # the UI never submits an empty list, so this guards against DSL/API paths that + # bypass the UI from accidentally widening the allowlist. + in_custom_bucket = input_file_type == FileType.CUSTOM or not type_allowed if ( - input_file_type == FileType.CUSTOM + in_custom_bucket and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions + and not _extension_matches(file_extension, config.allowed_file_extensions) ): return False diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 8c659086ed..a852f21bb2 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,14 +1,21 @@ +"""Workflow run response schemas for console APIs. + +Most workflow-run endpoints should document and serialize responses with the +Pydantic models in this module. The remaining Flask-RESTX field dictionaries are +kept only for workflow app-log endpoints that still build legacy log models. +""" + from __future__ import annotations from datetime import datetime from typing import Any from flask_restx import Namespace, fields -from pydantic import Field, field_validator +from pydantic import AliasChoices, Field, field_validator from fields.base import ResponseModel -from fields.end_user_fields import SimpleEndUser, simple_end_user_fields -from fields.member_fields import SimpleAccount, simple_account_fields +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount from libs.helper import TimestampField workflow_run_for_log_fields = { @@ -43,119 +50,6 @@ def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) -workflow_run_for_list_fields = { - "id": fields.String, - "version": fields.String, - "status": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, - "retry_index": fields.Integer, -} - -advanced_chat_workflow_run_for_list_fields = { - "id": fields.String, - "conversation_id": fields.String, - "message_id": fields.String, - "version": fields.String, - "status": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, - "retry_index": fields.Integer, -} - -advanced_chat_workflow_run_pagination_fields = { - "limit": fields.Integer(attribute="limit"), - "has_more": fields.Boolean(attribute="has_more"), - "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), -} - -workflow_run_pagination_fields = { - "limit": fields.Integer(attribute="limit"), - "has_more": fields.Boolean(attribute="has_more"), - "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), -} - -workflow_run_count_fields = { - "total": fields.Integer, - "running": fields.Integer, - "succeeded": fields.Integer, - "failed": fields.Integer, - "stopped": fields.Integer, - "partial_succeeded": fields.Integer(attribute="partial-succeeded"), -} - -workflow_run_detail_fields = { - "id": fields.String, - "version": fields.String, - "graph": fields.Raw(attribute="graph_dict"), - "inputs": fields.Raw(attribute="inputs_dict"), - "status": fields.String, - "outputs": fields.Raw(attribute="outputs_dict"), - "error": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, -} - -retry_event_field = { - "elapsed_time": fields.Float, - "status": fields.String, - "inputs": fields.Raw(attribute="inputs"), - "process_data": fields.Raw(attribute="process_data"), - "outputs": fields.Raw(attribute="outputs"), - "metadata": fields.Raw(attribute="metadata"), - "llm_usage": fields.Raw(attribute="llm_usage"), - "error": fields.String, - "retry_index": fields.Integer, -} - - -workflow_run_node_execution_fields = { - "id": fields.String, - "index": fields.Integer, - "predecessor_node_id": fields.String, - "node_id": fields.String, - "node_type": fields.String, - "title": fields.String, - "inputs": fields.Raw(attribute="inputs_dict"), - "process_data": fields.Raw(attribute="process_data_dict"), - "outputs": fields.Raw(attribute="outputs_dict"), - "status": fields.String, - "error": fields.String, - "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), - "extras": fields.Raw, - "created_at": TimestampField, - "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "finished_at": TimestampField, - "inputs_truncated": fields.Boolean, - "outputs_truncated": fields.Boolean, - "process_data_truncated": fields.Boolean, -} - -workflow_run_node_execution_list_fields = { - "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), -} - - def _to_timestamp(value: datetime | int | None) -> int | None: if isinstance(value, datetime): return int(value.timestamp()) @@ -252,7 +146,10 @@ class WorkflowRunCountResponse(ResponseModel): succeeded: int failed: int stopped: int - partial_succeeded: int = Field(validation_alias="partial-succeeded") + partial_succeeded: int = Field( + alias="partial_succeeded", + validation_alias=AliasChoices("partial_succeeded", "partial-succeeded"), + ) class WorkflowRunDetailResponse(ResponseModel): diff --git a/api/libs/external_api.py b/api/libs/external_api.py index f907d17750..64eb99a42b 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -9,6 +9,7 @@ from werkzeug.http import HTTP_STATUS_CODES from configs import dify_config from core.errors.error import AppInvokeQuotaExceededError +from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts from libs.token import build_force_logout_cookie_headers @@ -120,6 +121,7 @@ class ExternalApi(Api): } def __init__(self, app: Blueprint | Flask, *args, **kwargs): + patch_swagger_for_inline_nested_dicts() kwargs.setdefault("authorizations", self._authorizations) kwargs.setdefault("security", "Bearer") kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED diff --git a/api/libs/flask_restx_compat.py b/api/libs/flask_restx_compat.py new file mode 100644 index 0000000000..34e0d586a0 --- /dev/null +++ b/api/libs/flask_restx_compat.py @@ -0,0 +1,149 @@ +"""Compatibility helpers for Dify's Flask-RESTX Swagger integration. + +These helpers are temporary bridges for legacy Flask-RESTX field contracts +while controllers migrate their request and response documentation to Pydantic +models. Keep the behavior centralized so live Swagger endpoints and offline +spec export fail or succeed in the same way. +""" + +import hashlib +import json +from typing import TypeGuard + +from flask import current_app +from flask_restx import fields +from flask_restx.model import Model, OrderedModel, instance +from flask_restx.swagger import Swagger + + +def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]: + """Return whether a nested field map is an anonymous inline mapping.""" + + return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel)) + + +def _jsonable_schema_value(value: object) -> object: + """Return a deterministic JSON-serializable representation for schema fingerprints.""" + + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, list | tuple): + return [_jsonable_schema_value(item) for item in value] + if isinstance(value, dict): + return {str(key): _jsonable_schema_value(item) for key, item in value.items()} + value_type = type(value) + return f"<{value_type.__module__}.{value_type.__qualname__}>" + + +def _field_signature(field: object) -> object: + """Build a stable signature for a Flask-RESTX field object.""" + + field_instance = instance(field) + signature: dict[str, object] = { + "class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}" + } + + if isinstance(field_instance, fields.Nested): + nested = getattr(field_instance, "nested", None) + if _is_inline_field_map(nested): + signature["nested"] = _inline_model_signature(nested) + else: + signature["nested"] = getattr( + nested, + "name", + f"<{type(nested).__module__}.{type(nested).__qualname__}>", + ) + elif hasattr(field_instance, "container"): + signature["container"] = _field_signature(field_instance.container) + else: + schema = getattr(field_instance, "__schema__", None) + if isinstance(schema, dict): + signature["schema"] = _jsonable_schema_value(schema) + + for attr_name in ( + "attribute", + "default", + "description", + "example", + "max", + "max_items", + "min", + "min_items", + "nullable", + "readonly", + "required", + "title", + "unique", + ): + if hasattr(field_instance, attr_name): + signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name)) + + return signature + + +def _inline_model_signature(nested_fields: dict[object, object]) -> object: + """Build a stable signature for an anonymous inline model.""" + + return [ + (str(field_name), _field_signature(field)) + for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0])) + ] + + +def _inline_model_name(nested_fields: dict[object, object]) -> str: + """Return a stable Swagger model name for an anonymous inline field map.""" + + signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":")) + digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12] + return f"_AnonymousInlineModel_{digest}" + + +def patch_swagger_for_inline_nested_dicts() -> None: + """Allow Swagger generation to handle legacy inline Flask-RESTX field dicts. + + Some existing controllers use raw field mappings in `fields.Nested({...})` + or directly in `@namespace.response(...)`. Runtime marshalling accepts that, + but Flask-RESTX Swagger registration expects a named model. Convert those + anonymous mappings into temporary named models during docs generation. + """ + + if getattr(Swagger, "_dify_inline_nested_dict_patch", False): + return + + original_register_model = Swagger.register_model + original_register_field = Swagger.register_field + original_as_dict = Swagger.as_dict + + def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: + anonymous_name = _inline_model_name(nested_fields) + if anonymous_name not in self.api.models: + self.api.model(anonymous_name, nested_fields) + + return self.api.models[anonymous_name] + + def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: + if _is_inline_field_map(model): + model = get_or_create_inline_model(self, model) + + return original_register_model(self, model) + + def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: + nested = getattr(field, "nested", None) + if _is_inline_field_map(nested): + field.model = get_or_create_inline_model(self, nested) # type: ignore[attr-defined] + + original_register_field(self, field) + + def as_dict_with_inline_dict_support(self: Swagger): + # Temporary set RESTX_INCLUDE_ALL_MODELS = false to prevent "length changed while iterating" error + include_all_models = current_app.config.get("RESTX_INCLUDE_ALL_MODELS", False) + current_app.config["RESTX_INCLUDE_ALL_MODELS"] = False + try: + return original_as_dict(self) + finally: + current_app.config["RESTX_INCLUDE_ALL_MODELS"] = include_all_models + + Swagger.register_model = register_model_with_inline_dict_support + Swagger.register_field = register_field_with_inline_dict_support + Swagger.as_dict = as_dict_with_inline_dict_support + Swagger._dify_inline_nested_dict_patch = True diff --git a/api/models/dataset.py b/api/models/dataset.py index a00e9f7640..ed7727e0f1 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -24,7 +24,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 @@ -1020,7 +1020,7 @@ class DocumentSegment(Base): encoded_sign = base64.urlsafe_b64encode(sign).decode() params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - reference_url = dify_config.CONSOLE_API_URL or "" + reference_url = dify_config.FILES_URL or dify_config.CONSOLE_API_URL or "" base_url = f"{reference_url}/files/{upload_file_id}/image-preview" source_url = f"{base_url}?{params}" attachment_list.append( @@ -1162,7 +1162,7 @@ class DatasetQuery(TypeBase): "size": file_info.size, "extension": file_info.extension, "mime_type": file_info.mime_type, - "source_url": sign_upload_file(file_info.id, file_info.extension), + "source_url": sign_upload_file_preview_url(file_info.id, file_info.extension), } else: query["file_info"] = None diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index f4897e93c5..f3c188fc06 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -805,18 +805,17 @@ Get advanced chat workflow run list | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunListQuery](#workflowrunlistquery) | | app_id | path | Application ID | Yes | string | | last_id | query | Last run ID for pagination | No | string | -| limit | query | Number of items per page (1-100) | No | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| limit | query | Number of items per page (1-100) | No | integer | +| status | query | Workflow run status filter | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs retrieved successfully | [AdvancedChatWorkflowRunPagination](#advancedchatworkflowrunpagination) | +| 200 | Workflow runs retrieved successfully | [AdvancedChatWorkflowRunPaginationResponse](#advancedchatworkflowrunpaginationresponse) | ### /apps/{app_id}/advanced-chat/workflow-runs/count @@ -833,17 +832,16 @@ Get advanced chat workflow runs count statistics | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunCountQuery](#workflowruncountquery) | | app_id | path | Application ID | Yes | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | +| status | query | Workflow run status filter | No | string | | time_range | query | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs count retrieved successfully | [WorkflowRunCount](#workflowruncount) | +| 200 | Workflow runs count retrieved successfully | [WorkflowRunCountResponse](#workflowruncountresponse) | ### /apps/{app_id}/advanced-chat/workflows/draft/human-input/nodes/{node_id}/form/preview @@ -2361,18 +2359,17 @@ Get workflow run list | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunListQuery](#workflowrunlistquery) | | app_id | path | Application ID | Yes | string | | last_id | query | Last run ID for pagination | No | string | -| limit | query | Number of items per page (1-100) | No | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| limit | query | Number of items per page (1-100) | No | integer | +| status | query | Workflow run status filter | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs retrieved successfully | [WorkflowRunPagination](#workflowrunpagination) | +| 200 | Workflow runs retrieved successfully | [WorkflowRunPaginationResponse](#workflowrunpaginationresponse) | ### /apps/{app_id}/workflow-runs/count @@ -2389,17 +2386,16 @@ Get workflow runs count statistics | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunCountQuery](#workflowruncountquery) | | app_id | path | Application ID | Yes | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | +| status | query | Workflow run status filter | No | string | | time_range | query | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs count retrieved successfully | [WorkflowRunCount](#workflowruncount) | +| 200 | Workflow runs count retrieved successfully | [WorkflowRunCountResponse](#workflowruncountresponse) | ### /apps/{app_id}/workflow-runs/tasks/{task_id}/stop @@ -2449,7 +2445,7 @@ Get workflow run detail | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetail](#workflowrundetail) | +| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetailResponse](#workflowrundetailresponse) | | 404 | Workflow run not found | | ### /apps/{app_id}/workflow-runs/{run_id}/export @@ -2470,7 +2466,7 @@ Generate a download URL for an archived workflow run. | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Export URL generated | [WorkflowRunExport](#workflowrunexport) | +| 200 | Export URL generated | [WorkflowRunExportResponse](#workflowrunexportresponse) | ### /apps/{app_id}/workflow-runs/{run_id}/node-executions @@ -2494,7 +2490,7 @@ Get workflow run node execution list | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionList](#workflowrunnodeexecutionlist) | +| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionListResponse](#workflowrunnodeexecutionlistresponse) | | 404 | Workflow run not found | | ### /apps/{app_id}/workflow/comments @@ -3180,7 +3176,7 @@ Get last run result for draft workflow node | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecution](#workflowrunnodeexecution) | +| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | | 403 | Permission denied | | | 404 | Node last run not found | | @@ -3207,7 +3203,7 @@ Run draft workflow node | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node run started successfully | [WorkflowRunNodeExecution](#workflowrunnodeexecution) | +| 200 | Node run started successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | | 403 | Permission denied | | | 404 | Node not found | | @@ -6720,9 +6716,9 @@ Get workflow run list ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow runs retrieved successfully | [WorkflowRunPaginationResponse](#workflowrunpaginationresponse) | ### /rag/pipelines/{pipeline_id}/workflow-runs/tasks/{task_id}/stop @@ -6760,9 +6756,9 @@ Get workflow run detail ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetailResponse](#workflowrundetailresponse) | ### /rag/pipelines/{pipeline_id}/workflow-runs/{run_id}/node-executions @@ -6780,9 +6776,9 @@ Get workflow run node execution list ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionListResponse](#workflowrunnodeexecutionlistresponse) | ### /rag/pipelines/{pipeline_id}/workflows @@ -6915,9 +6911,9 @@ Set datasource variables ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Datasource variables set successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/environment-variables @@ -6988,9 +6984,9 @@ Run draft workflow loop node ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/nodes/{node_id}/run @@ -7009,9 +7005,9 @@ Run draft workflow node ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node run started successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/nodes/{node_id}/variables @@ -7947,6 +7943,7 @@ Get workflow pause details ##### Description +Get workflow pause details GET /console/api/workflow//pause-details Returns information about why and where the workflow is paused. @@ -7955,13 +7952,14 @@ Returns information about why and where the workflow is paused. | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| workflow_run_id | path | | Yes | string | +| workflow_run_id | path | Workflow run ID | Yes | string | ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow pause details retrieved successfully | [WorkflowPauseDetailsResponse](#workflowpausedetailsresponse) | +| 404 | Workflow run not found | | ### /workspaces @@ -10256,31 +10254,31 @@ Get banner list | ---- | ---- | ----------- | -------- | | result | string | Operation result | Yes | -#### AdvancedChatWorkflowRunForList +#### AdvancedChatWorkflowRunForListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| conversation_id | string | | No | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| elapsed_time | number | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| id | string | | No | -| message_id | string | | No | -| retry_index | integer | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| conversation_id | | | No | +| created_at | | | No | +| created_by_account | | | No | +| elapsed_time | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| id | string | | Yes | +| message_id | | | No | +| retry_index | | | No | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | -#### AdvancedChatWorkflowRunPagination +#### AdvancedChatWorkflowRunPaginationResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [AdvancedChatWorkflowRunForList](#advancedchatworkflowrunforlist) ] | | No | -| has_more | boolean | | No | -| limit | integer | | No | +| data | [ [AdvancedChatWorkflowRunForListResponse](#advancedchatworkflowrunforlistresponse) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | #### AdvancedChatWorkflowRunPayload @@ -12169,6 +12167,14 @@ Form input types. | form_inputs | object | Values the user provides for the form's own fields | Yes | | inputs | object | Values used to fill missing upstream variables referenced in form_content | Yes | +#### HumanInputPauseTypeResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| backstage_input_url | | | No | +| form_id | string | | Yes | +| type | string | | Yes | + #### IconType | Name | Type | Description | Required | @@ -13101,6 +13107,14 @@ Enum class for model type. | ---- | ---- | ----------- | -------- | | click_id | string | Click Id from partner referral link | Yes | +#### PausedNodeResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| node_id | string | | Yes | +| node_title | string | | Yes | +| pause_type | [HumanInputPauseTypeResponse](#humaninputpausetyperesponse) | | Yes | + #### Payload | Name | Type | Description | Required | @@ -14306,53 +14320,60 @@ User action configuration. | updated_at | | | No | | updated_by | | | No | -#### WorkflowRunCount +#### WorkflowPauseDetailsResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| failed | integer | | No | -| partial_succeeded | integer | | No | -| running | integer | | No | -| stopped | integer | | No | -| succeeded | integer | | No | -| total | integer | | No | +| paused_at | | | No | +| paused_nodes | [ [PausedNodeResponse](#pausednoderesponse) ] | | Yes | #### WorkflowRunCountQuery | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | status | | Workflow run status filter | No | -| time_range | | Time range filter (e.g., 7d, 4h, 30m, 30s) | No | -| triggered_from | | Filter by trigger source: debugging or app-run | No | +| time_range | | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | +| triggered_from | | Filter by trigger source: debugging or app-run. Default: debugging | No | -#### WorkflowRunDetail +#### WorkflowRunCountResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| created_by_end_user | [SimpleEndUser](#simpleenduser) | | No | -| created_by_role | string | | No | -| elapsed_time | number | | No | -| error | string | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| graph | object | | No | -| id | string | | No | -| inputs | object | | No | -| outputs | object | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| failed | integer | | Yes | +| partial_succeeded | integer | | Yes | +| running | integer | | Yes | +| stopped | integer | | Yes | +| succeeded | integer | | Yes | +| total | integer | | Yes | -#### WorkflowRunExport +#### WorkflowRunDetailResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| presigned_url | string | Pre-signed URL for download | No | -| presigned_url_expires_at | string | Pre-signed URL expiration time | No | -| status | string | Export status: success/failed | No | +| created_at | | | No | +| created_by_account | | | No | +| created_by_end_user | | | No | +| created_by_role | | | No | +| elapsed_time | | | No | +| error | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| graph | | | Yes | +| id | string | | Yes | +| inputs | | | Yes | +| outputs | | | Yes | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | + +#### WorkflowRunExportResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| presigned_url | | Pre-signed URL for download | No | +| presigned_url_expires_at | | Pre-signed URL expiration time | No | +| status | string | Export status: success/failed | Yes | #### WorkflowRunForArchivedLogResponse @@ -14364,21 +14385,21 @@ User action configuration. | total_tokens | | | No | | triggered_from | | | No | -#### WorkflowRunForList +#### WorkflowRunForListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| elapsed_time | number | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| id | string | | No | -| retry_index | integer | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| created_at | | | No | +| created_by_account | | | No | +| elapsed_time | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| id | string | | Yes | +| retry_index | | | No | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | #### WorkflowRunForLogResponse @@ -14403,48 +14424,48 @@ User action configuration. | last_id | | Last run ID for pagination | No | | limit | integer | Number of items per page (1-100) | No | | status | | Workflow run status filter | No | -| triggered_from | | Filter by trigger source: debugging or app-run | No | +| triggered_from | | Filter by trigger source: debugging or app-run. Default: debugging | No | -#### WorkflowRunNodeExecution +#### WorkflowRunNodeExecutionListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| created_by_end_user | [SimpleEndUser](#simpleenduser) | | No | -| created_by_role | string | | No | -| elapsed_time | number | | No | -| error | string | | No | -| execution_metadata | object | | No | -| extras | object | | No | -| finished_at | object | | No | -| id | string | | No | -| index | integer | | No | -| inputs | object | | No | -| inputs_truncated | boolean | | No | -| node_id | string | | No | -| node_type | string | | No | -| outputs | object | | No | -| outputs_truncated | boolean | | No | -| predecessor_node_id | string | | No | -| process_data | object | | No | -| process_data_truncated | boolean | | No | -| status | string | | No | -| title | string | | No | +| data | [ [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) ] | | Yes | -#### WorkflowRunNodeExecutionList +#### WorkflowRunNodeExecutionResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [WorkflowRunNodeExecution](#workflowrunnodeexecution) ] | | No | +| created_at | | | No | +| created_by_account | | | No | +| created_by_end_user | | | No | +| created_by_role | | | No | +| elapsed_time | | | No | +| error | | | No | +| execution_metadata | | | No | +| extras | | | No | +| finished_at | | | No | +| id | string | | Yes | +| index | | | No | +| inputs | | | No | +| inputs_truncated | | | No | +| node_id | | | No | +| node_type | | | No | +| outputs | | | No | +| outputs_truncated | | | No | +| predecessor_node_id | | | No | +| process_data | | | No | +| process_data_truncated | | | No | +| status | | | No | +| title | | | No | -#### WorkflowRunPagination +#### WorkflowRunPaginationResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [WorkflowRunForList](#workflowrunforlist) ] | | No | -| has_more | boolean | | No | -| limit | integer | | No | +| data | [ [WorkflowRunForListResponse](#workflowrunforlistresponse) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | #### WorkflowRunPayload diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 0229a1f43a..aa6b8ffc6e 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -425,7 +425,7 @@ class AppAnnotationService: return {"deleted_count": deleted_count} @classmethod - def batch_import_app_annotations(cls, app_id, file: FileStorage): + def batch_import_app_annotations(cls, app_id: str, file: FileStorage): """ Batch import annotations from CSV file with enhanced security checks. diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 60948e652b..c80b2f43fd 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -54,7 +54,7 @@ class AudioService: if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() - file_content = file.read() + file_content = file.stream.read() file_size = len(file_content) if file_size > FILE_SIZE_LIMIT: diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 889717df72..cff735b39d 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -121,9 +121,7 @@ class TriggerSubscriptionBuilderService: if not subscription_builder.name: raise ValueError("Subscription builder name is required") - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( @@ -321,9 +319,7 @@ class TriggerSubscriptionBuilderService: raise ValueError("Subscription builder name is required") # Build - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d99900a04..592f678421 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -402,7 +402,7 @@ class WebhookService: for name, file in files.items(): if file and file.filename: try: - file_content = file.read() + file_content = file.stream.read() mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream" file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) processed_files[name] = file_obj.to_dict() diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c17a83cad3..ba59780d59 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -44,6 +45,35 @@ def unwrap(func): return func +def make_node_execution(**overrides): + payload = { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node1", + "node_type": "start", + "title": "Start", + "inputs_dict": {"query": "hello"}, + "process_data_dict": {}, + "outputs_dict": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata_dict": {}, + "extras": {}, + "created_at": datetime(2026, 1, 1, 0, 0, 0), + "created_by_role": "account", + "created_by_account": None, + "created_by_end_user": None, + "finished_at": datetime(2026, 1, 1, 0, 0, 1), + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + class TestDraftWorkflowApi: @pytest.fixture def app(self, flask_app_with_containers: Flask): @@ -743,7 +773,7 @@ class TestRagPipelineWorkflowLastRunApi: pipeline = MagicMock() workflow = MagicMock() - node_exec = MagicMock() + node_exec = make_node_execution() service = MagicMock() service.get_draft_workflow.return_value = workflow @@ -757,7 +787,9 @@ class TestRagPipelineWorkflowLastRunApi: ), ): result = method(api, pipeline, "node1") - assert result == node_exec + assert result["id"] == "node-exec-1" + assert result["inputs"] == {"query": "hello"} + assert result["outputs"] == {"answer": "world"} def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() @@ -799,7 +831,7 @@ class TestRagPipelineDatasourceVariableApi: } service = MagicMock() - service.set_datasource_variables.return_value = MagicMock() + service.set_datasource_variables.return_value = make_node_execution(node_id="n1") with ( app.test_request_context("/", json=payload), @@ -814,4 +846,5 @@ class TestRagPipelineDatasourceVariableApi: ), ): result = method(api, pipeline) - assert result is not None + assert result["node_id"] == "n1" + assert result["process_data"] == {} diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 6d5c7380b7..52b1229302 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -543,8 +543,8 @@ class TestWebhookService: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" diff --git a/api/tests/unit_tests/controllers/common/test_schema.py b/api/tests/unit_tests/controllers/common/test_schema.py index 575f8c839c..7cabafba0e 100644 --- a/api/tests/unit_tests/controllers/common/test_schema.py +++ b/api/tests/unit_tests/controllers/common/test_schema.py @@ -47,6 +47,10 @@ class QueryModel(BaseModel): ambiguous: int | str | None = Field(default=None, description="Ambiguous query parameter") +class ResponseAliasModel(BaseModel): + public_name: str = Field(validation_alias="internal_name") + + @pytest.fixture(autouse=True) def mock_console_ns(): """Mock the console_ns to avoid circular imports during test collection.""" @@ -146,6 +150,20 @@ def test_register_schema_models_calls_register_schema_model(monkeypatch: pytest. ] +def test_register_response_schema_model_uses_serialized_field_names(): + from controllers.common.schema import register_response_schema_model + + namespace = MagicMock(spec=Namespace) + + register_response_schema_model(namespace, ResponseAliasModel) + + model_name, schema = namespace.schema_model.call_args.args + + assert model_name == "ResponseAliasModel" + assert "public_name" in schema["properties"] + assert "internal_name" not in schema["properties"] + + def test_get_or_create_model_returns_existing_model(mock_console_ns): from controllers.common.schema import get_or_create_model diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index c4a8148446..05c17b4e34 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -112,3 +112,24 @@ def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPa with pytest.raises(NotFoundError): with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + +def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + _patch_console_guards(monkeypatch, account) + + workflow_run = Mock(spec=WorkflowRun) + workflow_run.tenant_id = "tenant-123" + workflow_run.status = WorkflowExecutionStatus.RUNNING + fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) + monkeypatch.setattr(workflow_run_module, "db", fake_db) + + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + assert status == 200 + assert response == {"paused_at": None, "paused_nodes": []} + + +def test_pause_details_response_schema_is_registered() -> None: + assert workflow_run_module.WorkflowPauseDetailsResponse.__name__ in workflow_run_module.console_ns.models diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py new file mode 100644 index 0000000000..e225e31563 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask +from flask_restx import marshal + +from controllers.console.app import workflow_run as workflow_run_module + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _serialize_200_response(handler, payload: Any) -> Any: + response_doc = getattr(handler, "__apidoc__", {}).get("responses", {}).get("200") + if response_doc is None: + return payload + + response_model = response_doc[1] + if isinstance(response_model, dict): + return marshal(payload, response_model) + return payload + + +def _account() -> SimpleNamespace: + return SimpleNamespace(id="account-1", name="Alice", email="alice@example.com") + + +def _workflow_run_summary(**overrides) -> SimpleNamespace: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + payload = { + "id": "run-1", + "version": "v1", + "status": "succeeded", + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_account": _account(), + "created_at": created_at, + "finished_at": created_at, + "exceptions_count": 0, + "retry_index": 0, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + +def _workflow_run_node_execution(**overrides) -> SimpleNamespace: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + payload = { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node-1", + "node_type": "start", + "title": "Start", + "inputs_dict": {"query": "hello"}, + "process_data_dict": {"step": "prepared"}, + "outputs_dict": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata_dict": {"total_tokens": 3}, + "extras": {}, + "created_at": created_at, + "created_by_role": "account", + "created_by_account": _account(), + "created_by_end_user": None, + "finished_at": created_at, + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + +def test_workflow_run_list_returns_frontend_history_contract(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + class WorkflowRunService: + def get_paginate_workflow_runs(self, **_kwargs): + return { + "limit": 10, + "has_more": False, + "data": [_workflow_run_summary()], + } + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.WorkflowRunListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs?limit=10", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + response = _serialize_200_response(api.get, payload) + + assert response["limit"] == 10 + assert response["has_more"] is False + assert response["data"][0] == { + "id": "run-1", + "version": "v1", + "status": "succeeded", + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_at": 1767323045, + "finished_at": 1767323045, + "exceptions_count": 0, + "retry_index": 0, + } + + +def test_advanced_chat_workflow_run_list_keeps_message_fields(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + class WorkflowRunService: + def get_paginate_advanced_chat_workflow_runs(self, **_kwargs): + return { + "limit": 1, + "has_more": True, + "data": [ + _workflow_run_summary( + conversation_id="conversation-1", + message_id="message-1", + ) + ], + } + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.AdvancedChatAppWorkflowRunListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/advanced-chat/workflow-runs?limit=1", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + response = _serialize_200_response(api.get, payload) + + assert response["data"][0]["conversation_id"] == "conversation-1" + assert response["data"][0]["message_id"] == "message-1" + + +def test_workflow_run_detail_returns_frontend_detail_contract(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + workflow_run = SimpleNamespace( + id="run-1", + version="v1", + graph_dict={"nodes": []}, + inputs_dict={"query": "hello"}, + status="succeeded", + outputs_dict={"answer": "world"}, + error=None, + elapsed_time=1.5, + total_tokens=10, + total_steps=2, + created_by_role="account", + created_by_account=_account(), + created_by_end_user=None, + created_at=created_at, + finished_at=created_at, + exceptions_count=0, + ) + + class WorkflowRunService: + def get_workflow_run(self, **_kwargs): + return workflow_run + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.WorkflowRunDetailApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs/run-1", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") + + response = _serialize_200_response(api.get, payload) + + assert response == { + "id": "run-1", + "version": "v1", + "graph": {"nodes": []}, + "inputs": {"query": "hello"}, + "status": "succeeded", + "outputs": {"answer": "world"}, + "error": None, + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_role": "account", + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_by_end_user": None, + "created_at": 1767323045, + "finished_at": 1767323045, + "exceptions_count": 0, + } + + +def test_workflow_run_node_executions_return_frontend_trace_contract( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + class WorkflowRunService: + def get_workflow_run_node_executions(self, **_kwargs): + return [_workflow_run_node_execution()] + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + monkeypatch.setattr(workflow_run_module, "current_user", SimpleNamespace(id="account-1")) + + api = workflow_run_module.WorkflowRunNodeExecutionListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") + + response = _serialize_200_response(api.get, payload) + + assert response == { + "data": [ + { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node-1", + "node_type": "start", + "title": "Start", + "inputs": {"query": "hello"}, + "process_data": {"step": "prepared"}, + "outputs": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata": {"total_tokens": 3}, + "extras": {}, + "created_at": 1767323045, + "created_by_role": "account", + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_by_end_user": None, + "finished_at": 1767323045, + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + ] + } diff --git a/api/tests/unit_tests/controllers/files/test_upload.py b/api/tests/unit_tests/controllers/files/test_upload.py index e8f3cd4b66..ff6ba0e9a1 100644 --- a/api/tests/unit_tests/controllers/files/test_upload.py +++ b/api/tests/unit_tests/controllers/files/test_upload.py @@ -1,3 +1,4 @@ +import io import types from unittest.mock import patch @@ -30,9 +31,10 @@ class DummyFile: self.filename = filename self.mimetype = mimetype self._content = content + self.stream = io.BytesIO(content) def read(self): - return self._content + return self.stream.read() class DummyToolFile: diff --git a/api/tests/unit_tests/controllers/test_swagger.py b/api/tests/unit_tests/controllers/test_swagger.py new file mode 100644 index 0000000000..999f1ae78d --- /dev/null +++ b/api/tests/unit_tests/controllers/test_swagger.py @@ -0,0 +1,72 @@ +"""Swagger JSON rendering tests for Flask-RESTX API blueprints.""" + +import pytest +from flask import Flask + + +def _definition_refs(value: object) -> set[str]: + refs: set[str] = set() + if isinstance(value, dict): + ref = value.get("$ref") + if isinstance(ref, str) and ref.startswith("#/definitions/"): + refs.add(ref.removeprefix("#/definitions/")) + for item in value.values(): + refs.update(_definition_refs(item)) + elif isinstance(value, list): + for item in value: + refs.update(_definition_refs(item)) + return refs + + +@pytest.mark.parametrize( + ("first_kwargs", "second_kwargs"), + [ + ({"min_items": 1}, {"min_items": 2}), + ({"max_items": 1}, {"max_items": 2}), + ({"unique": True}, {"unique": False}), + ], +) +def test_inline_model_name_includes_list_constraints( + first_kwargs: dict[str, object], + second_kwargs: dict[str, object], +): + from flask_restx import fields + + from libs.flask_restx_compat import _inline_model_name + + first_inline_model: dict[object, object] = {"items": fields.List(fields.String, **first_kwargs)} + second_inline_model: dict[object, object] = {"items": fields.List(fields.String, **second_kwargs)} + + assert _inline_model_name(first_inline_model) != _inline_model_name(second_inline_model) + + +def test_swagger_json_endpoints_render(monkeypatch: pytest.MonkeyPatch): + from configs import dify_config + from controllers.console import bp as console_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + monkeypatch.setattr(dify_config, "SWAGGER_UI_ENABLED", True) + + app = Flask(__name__) + app.config["TESTING"] = True + app.config["RESTX_INCLUDE_ALL_MODELS"] = True + app.register_blueprint(console_bp) + app.register_blueprint(web_bp) + app.register_blueprint(service_api_bp) + + client = app.test_client() + + for route in ("/console/api/swagger.json", "/api/swagger.json", "/v1/swagger.json"): + response = client.get(route) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["swagger"] == "2.0" + assert "paths" in payload + assert "definitions" in payload + assert isinstance(payload["definitions"], dict) + missing_refs = _definition_refs(payload) - set(payload["definitions"]) + assert not sorted(ref for ref in missing_refs if ref.startswith("_AnonymousInlineModel")) + + assert app.config["RESTX_INCLUDE_ALL_MODELS"] is True diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py index 3fd21ab22b..62e1d22129 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -1,3 +1,4 @@ +from collections import UserString from unittest.mock import MagicMock import pytest @@ -12,21 +13,25 @@ from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( # ----------------------------- -class DummyEnumValue: +class DummyEnumValue(UserString): def __init__(self, value): + super().__init__(value) self.value = value class DummyPromptType: def __init__(self): - self.SIMPLE = "simple" - self.ADVANCED = "advanced" + self.SIMPLE = DummyEnumValue("simple") + self.ADVANCED = DummyEnumValue("advanced") def value_of(self, value): - return value + for enum_value in self: + if enum_value.value == value: + return enum_value + raise ValueError(f"invalid prompt type value {value}") def __iter__(self): - return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + return iter([self.SIMPLE, self.ADVANCED]) # ----------------------------- diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index f459250b8e..72c24bda96 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -198,6 +198,48 @@ class TestBuildPromptMessageWithFiles: assert isinstance(result.content[-1], TextPromptMessageContent) assert result.content[-1].data == "user text" + def test_replay_does_not_pass_config_to_file_factory(self): + """Replay contract: history files were validated on upload, so this + path must not forward a FileUploadConfig. The factory's signature + no longer accepts ``config``; this test guards against a future + regression that re-introduces it.""" + conv = _make_conversation(AppMode.CHAT) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None + + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=MagicMock(), + ) as mock_build, + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + + mock_build.assert_called_once() + assert "config" not in mock_build.call_args.kwargs + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) def test_chat_mode_with_files_assistant_message(self, mode): """When files are present, returns AssistantPromptMessage with list content.""" diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index d38213dd89..f72351ffa2 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -1038,7 +1038,7 @@ class TestRetrievalServiceInternals: assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) processor_instance.invoke.assert_called_once() - @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + @patch("core.rag.datasource.retrieval_service.sign_upload_file_preview_url", return_value="signed://file") def test_get_segment_attachment_info_success(self, mock_sign): upload_file = SimpleNamespace( id="upload-1", @@ -1118,7 +1118,7 @@ class TestRetrievalServiceInternals: assert result == [] - @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + @patch("core.rag.datasource.retrieval_service.sign_upload_file_preview_url", return_value="signed://file") def test_get_segment_attachment_infos_success(self, mock_sign): upload_file_1 = SimpleNamespace( id="upload-1", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b556ddf528..9334ad9b2f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4562,7 +4562,7 @@ class TestRetrieveCoverage: "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", return_value=[record], ), - patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file_preview_url", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): bound_model_instance = Mock() diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index 353988d7a6..a75fdee908 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -9,7 +9,7 @@ import pytest from core.tools.signature import ( get_signed_file_url_for_plugin, sign_tool_file, - sign_upload_file, + sign_upload_file_preview_url, verify_plugin_file_signature, verify_tool_file_signature, ) @@ -89,32 +89,32 @@ def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytes assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False -def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None: +def test_sign_upload_file_preview_url_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16) monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") - url = sign_upload_file("upload-id", ".png") + url = sign_upload_file_preview_url("upload-id", ".png") parsed = urlparse(url) query = parse_qs(parsed.query) - assert parsed.netloc == "internal.example.com" + assert parsed.netloc == "files.example.com" assert parsed.path == "/files/upload-id/image-preview" assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] -def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: +def test_sign_upload_file_preview_url_ignores_internal_files_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16) monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") - monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") - url = sign_upload_file("upload-id", ".png") + url = sign_upload_file_preview_url("upload-id", ".png") parsed = urlparse(url) query = parse_qs(parsed.query) diff --git a/api/tests/unit_tests/factories/test_file_validation.py b/api/tests/unit_tests/factories/test_file_validation.py new file mode 100644 index 0000000000..61337fcf10 --- /dev/null +++ b/api/tests/unit_tests/factories/test_file_validation.py @@ -0,0 +1,159 @@ +"""Unit tests for is_file_valid_with_config.""" + +from __future__ import annotations + +import pytest + +from factories.file_factory.validation import is_file_valid_with_config +from graphon.file import FileTransferMethod, FileType, FileUploadConfig + + +def _validate( + *, + input_file_type: str, + file_extension: str = ".png", + file_transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + config: FileUploadConfig, +) -> bool: + return is_file_valid_with_config( + input_file_type=input_file_type, + file_extension=file_extension, + file_transfer_method=file_transfer_method, + config=config, + ) + + +@pytest.mark.parametrize( + ("input_file_type", "file_extension", "allowed_file_types", "allowed_file_extensions", "expected"), + [ + # round-1 happy path: literal "custom" mapping, ext whitelisted + ("custom", ".png", [FileType.CUSTOM], [".png"], True), + # round-2 replay: MessageFile.type is the resolved type, but config still allows CUSTOM + ("image", ".png", [FileType.CUSTOM], [".png"], True), + ("document", ".pdf", [FileType.CUSTOM], [".pdf"], True), + # mixed bucket [IMAGE, CUSTOM]: document falls into CUSTOM bucket via extension + ("document", ".pdf", [FileType.IMAGE, FileType.CUSTOM], [".pdf"], True), + ("document", ".exe", [FileType.IMAGE, FileType.CUSTOM], [".pdf"], False), + ("image", ".jpg", [FileType.IMAGE], [], True), + ("video", ".mp4", [FileType.IMAGE, FileType.DOCUMENT], [], False), + ("custom", ".exe", [FileType.CUSTOM], [".png"], False), + # empty allowed_file_types == no type restriction + ("video", ".mp4", [], [], True), + ], +) +def test_bucket_semantics(input_file_type, file_extension, allowed_file_types, allowed_file_extensions, expected): + config = FileUploadConfig( + allowed_file_types=allowed_file_types, + allowed_file_extensions=allowed_file_extensions, + ) + assert _validate(input_file_type=input_file_type, file_extension=file_extension, config=config) is expected + + +@pytest.mark.parametrize("whitelist_entry", [".png", ".PNG", "png", "PNG", " .Png ", "PnG"]) +def test_extension_match_is_case_and_dot_insensitive(whitelist_entry): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[whitelist_entry], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + + +def test_extension_mismatch_still_rejected_after_normalization(): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".png", ".jpg"], + ) + assert _validate(input_file_type="custom", file_extension=".pdf", config=config) is False + + +def test_mixed_case_whitelist_replicating_real_user_config(): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".PNG", "png", "JPG", ".WEBP", "SVG", "GIF"], + ) + for ext in (".png", ".jpg", ".webp", ".svg", ".gif"): + assert _validate(input_file_type="custom", file_extension=ext, config=config) is True + + +def test_tool_file_always_passes(): + config = FileUploadConfig(allowed_file_types=[FileType.CUSTOM], allowed_file_extensions=[".pdf"]) + assert ( + _validate( + input_file_type="image", + file_extension=".png", + file_transfer_method=FileTransferMethod.TOOL_FILE, + config=config, + ) + is True + ) + + +def test_transfer_method_gate_for_non_image(): + config = FileUploadConfig( + allowed_file_types=[FileType.DOCUMENT], + allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE], + ) + assert ( + _validate( + input_file_type="document", + file_extension=".pdf", + file_transfer_method=FileTransferMethod.LOCAL_FILE, + config=config, + ) + is True + ) + assert ( + _validate( + input_file_type="document", + file_extension=".pdf", + file_transfer_method=FileTransferMethod.REMOTE_URL, + config=config, + ) + is False + ) + + +def test_history_replay_matches_round_1_outcome_under_unchanged_config(): + """A file that passes round 1 must pass history replay when config is unchanged.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".png"], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + assert _validate(input_file_type="image", file_extension=".png", config=config) is True + + +def test_empty_whitelist_in_custom_bucket_denies_by_default(): + """Defensive: when a file lands in the CUSTOM bucket, an empty + allowed_file_extensions list rejects. The UI never submits empty; + this guards DSL / API paths that bypass the UI from accidentally + widening what's accepted.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is False + assert _validate(input_file_type="image", file_extension=".png", config=config) is False + + +def test_normalize_handles_whitespace_and_empty_consistently(): + """Whitespace-only or empty entries in the whitelist must not match real + extensions (regression guard for _normalize_extension edge cases).""" + for noisy_entry in ("", " ", "\t"): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[noisy_entry], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is False + + +def test_empty_extension_does_not_spuriously_match_empty_whitelist_entry(): + """Defensive: even if the whitelist contains an empty / whitespace entry + (e.g., a stray comma in DSL), an extensionless file must not pass via + a both-sides-empty match. Real entries in the same whitelist still match.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=["", ".png"], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + assert _validate(input_file_type="custom", file_extension="", config=config) is False diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 51d95c4239..3f14ebe8bf 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -12,7 +12,9 @@ This test suite covers: import json import pickle from datetime import UTC, datetime +from types import SimpleNamespace from unittest.mock import Mock, patch +from urllib.parse import parse_qs, urlparse from uuid import uuid4 from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -676,6 +678,51 @@ class TestDocumentSegmentIndexing: # Assert assert segment.hit_count == 5 + def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch): + """Test attachment source URLs use FILES_URL before falling back to CONSOLE_API_URL.""" + # Arrange + segment = DocumentSegment( + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="document-1", + position=1, + content="Test", + word_count=1, + tokens=2, + created_by="user-1", + ) + segment.id = "segment-1" + attachment = SimpleNamespace( + id="upload-1", + name="image.png", + size=128, + extension="png", + mime_type="image/png", + ) + + monkeypatch.setattr("models.dataset.time.time", lambda: 1700000000) + monkeypatch.setattr("models.dataset.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("models.dataset.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("models.dataset.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("models.dataset.dify_config.CONSOLE_API_URL", "https://console.example.com") + + with patch("models.dataset.db") as mock_db: + mock_db.session.execute.return_value.all.return_value = [(Mock(), attachment)] + + # Act + attachments = segment.attachments + + # Assert + assert len(attachments) == 1 + source_url = attachments[0]["source_url"] + parsed = urlparse(source_url) + query = parse_qs(parsed.query) + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/upload-1/image-preview" + assert query["timestamp"] == ["1700000000"] + assert query["nonce"] == ["01010101010101010101010101010101"] + assert query["sign"][0] + def test_document_segment_error_tracking(self): """Test document segment error tracking.""" # Arrange diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 83258fd1b7..5d148974f8 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -173,7 +173,8 @@ class AudioServiceTestDataFactory: file = Mock(spec=FileStorage) file.filename = filename file.mimetype = mimetype - file.read = Mock(return_value=content) + file.stream = Mock() + file.stream.read = Mock(return_value=content) for key, value in kwargs.items(): setattr(file, key, value) return file @@ -216,7 +217,7 @@ class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful ASR transcription in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -241,7 +242,9 @@ class TestAudioServiceASR: mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_advanced_chat_mode( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}}) @@ -263,7 +266,7 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Workflow transcribed text"} - def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when speech-to-text is disabled in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False}) @@ -277,7 +280,9 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode( + self, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}}) @@ -291,7 +296,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_workflow_missing(self, factory): + def test_transcript_asr_raises_error_when_workflow_missing(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when workflow is missing in WORKFLOW mode.""" # Arrange app = factory.create_app_mock( @@ -304,7 +309,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory): + def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when no file is uploaded.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -317,7 +322,7 @@ class TestAudioServiceASR: with pytest.raises(NoAudioUploadedServiceError): AudioService.transcript_asr(app_model=app, file=None) - def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory): + def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error for unsupported audio file types.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -331,7 +336,7 @@ class TestAudioServiceASR: with pytest.raises(UnsupportedAudioTypeServiceError): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_for_large_file(self, factory): + def test_transcript_asr_raises_error_for_large_file(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when file exceeds size limit (30MB).""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -348,7 +353,9 @@ class TestAudioServiceASR: AudioService.transcript_asr(app_model=app, file=file) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_asr_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when no model instance is available.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -371,7 +378,7 @@ class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): + def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful TTS with text input.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -405,7 +412,7 @@ class TestAudioServiceTTS: ) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): + def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test TTS uses default voice when none specified.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -435,7 +442,9 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "default-voice" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): + def test_transcript_tts_gets_first_available_voice_when_none_configured( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test TTS gets first available voice when none is configured.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -467,7 +476,7 @@ class TestAudioServiceTTS: @patch("services.audio_service.WorkflowService", autospec=True) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( - self, mock_model_manager_class, mock_workflow_service_class, factory + self, mock_model_manager_class, mock_workflow_service_class, factory: AudioServiceTestDataFactory ): """Test TTS in WORKFLOW mode with draft workflow.""" # Arrange @@ -499,7 +508,7 @@ class TestAudioServiceTTS: assert result == b"draft audio" mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) - def test_transcript_tts_raises_error_when_text_missing(self, factory): + def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" # Arrange app = factory.create_app_mock() @@ -509,7 +518,9 @@ class TestAudioServiceTTS: AudioService.transcript_tts(app_model=app, text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): + def test_transcript_tts_raises_error_when_no_voices_available( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS raises error when no voices are available.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -535,7 +546,7 @@ class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful retrieval of TTS voices.""" # Arrange tenant_id = "tenant-123" @@ -560,7 +571,9 @@ class TestAudioServiceTTSVoices: mock_model_instance.get_tts_voices.assert_called_once_with(language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices raises error when no model instance is available.""" # Arrange tenant_id = "tenant-123" @@ -575,7 +588,9 @@ class TestAudioServiceTTSVoices: AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_propagates_exceptions( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices propagates exceptions from model instance.""" # Arrange tenant_id = "tenant-123" diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 95edc436d7..a2b56fe777 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -268,8 +268,8 @@ class TestWebhookServiceUnit: } # Mock file reads - files["file1"].read.return_value = b"content1" - files["file2"].read.return_value = b"content2" + files["file1"].stream.read.return_value = b"content1" + files["file2"].stream.read.return_value = b"content2" webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -304,8 +304,8 @@ class TestWebhookServiceUnit: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 2326e92d2f..2de84456ee 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -159,21 +159,11 @@ "count": 5 } }, - "web/app/account/(commonLayout)/delete-account/components/feed-back.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/account/(commonLayout)/delete-account/components/verify-email.tsx": { "react/set-state-in-effect": { "count": 1 } }, - "web/app/account/(commonLayout)/delete-account/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/account/oauth/authorize/layout.tsx": { "ts/no-explicit-any": { "count": 1 @@ -202,18 +192,10 @@ "count": 1 } }, - "web/app/components/app/annotation/add-annotation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/annotation/batch-add-annotation-modal/index.tsx": { "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 }, @@ -235,11 +217,6 @@ "count": 1 } }, - "web/app/components/app/annotation/edit-annotation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/annotation/header-opts/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -262,9 +239,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 5 }, @@ -272,16 +246,16 @@ "count": 1 } }, + "web/app/components/app/app-access-control/add-member-or-group-pop.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/app/app-publisher/features-wrapper.tsx": { "ts/no-explicit-any": { "count": 4 } }, - "web/app/components/app/app-publisher/version-info-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/base/var-highlight/index.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -293,9 +267,6 @@ } }, "web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -311,9 +282,6 @@ } }, "web/app/components/app/configuration/config-var/config-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 4 } @@ -337,9 +305,6 @@ } }, "web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react-hooks/exhaustive-deps": { "count": 1 }, @@ -356,9 +321,6 @@ } }, "web/app/components/app/configuration/config/automatic/get-automatic-res.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -387,9 +349,6 @@ } }, "web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -397,16 +356,6 @@ "count": 2 } }, - "web/app/components/app/configuration/configuration-view.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/app/configuration/dataset-config/card-item/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/dataset-config/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -418,9 +367,6 @@ } }, "web/app/components/app/configuration/dataset-config/params-config/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 } @@ -494,26 +440,10 @@ "count": 1 } }, - "web/app/components/app/create-app-modal/index.tsx": { - "react/set-state-in-effect": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/create-from-dsl-modal/index.tsx": { "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 }, @@ -521,11 +451,6 @@ "count": 2 } }, - "web/app/components/app/duplicate-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/log/filter.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -537,9 +462,6 @@ } }, "web/app/components/app/log/list.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 6 }, @@ -561,9 +483,6 @@ } }, "web/app/components/app/switch-app-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 } @@ -584,9 +503,6 @@ } }, "web/app/components/app/workflow-log/list.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 2 } @@ -881,11 +797,6 @@ "count": 3 } }, - "web/app/components/base/content-dialog/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/date-and-time-picker/hooks.ts": { "react/no-unnecessary-use-prefix": { "count": 2 @@ -901,26 +812,6 @@ "count": 1 } }, - "web/app/components/base/dialog/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, - "web/app/components/base/drawer-plus/index.stories.tsx": { - "react/component-hook-factories": { - "count": 1 - } - }, - "web/app/components/base/drawer-plus/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/base/emoji-picker/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/error-boundary/index.tsx": { "react-refresh/only-export-components": { "count": 3 @@ -942,11 +833,6 @@ "count": 1 } }, - "web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx": { "ts/no-explicit-any": { "count": 3 @@ -973,9 +859,6 @@ } }, "web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -1031,11 +914,6 @@ "count": 3 } }, - "web/app/components/base/float-right-container/index.tsx": { - "no-restricted-imports": { - "count": 2 - } - }, "web/app/components/base/form/components/base/base-form.tsx": { "ts/no-explicit-any": { "count": 6 @@ -1051,14 +929,6 @@ "count": 1 } }, - "web/app/components/base/form/components/field/variable-or-constant-input.tsx": { - "no-console": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/base/form/components/field/variable-selector.tsx": { "no-console": { "count": 1 @@ -1552,16 +1422,6 @@ "count": 1 } }, - "web/app/components/base/modal-like-wrap/index.stories.tsx": { - "no-console": { - "count": 3 - } - }, - "web/app/components/base/modal/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/new-audio-button/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1605,7 +1465,12 @@ }, "web/app/components/base/prompt-editor/index.tsx": { "ts/no-explicit-any": { - "count": 4 + "count": 3 + } + }, + "web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx": { + "no-restricted-imports": { + "count": 1 } }, "web/app/components/base/prompt-editor/plugins/component-picker-block/menu.tsx": { @@ -1693,8 +1558,8 @@ } }, "web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx": { - "ts/no-explicit-any": { - "count": 2 + "no-restricted-imports": { + "count": 1 } }, "web/app/components/base/prompt-editor/plugins/update-block.tsx": { @@ -1848,11 +1713,6 @@ "count": 4 } }, - "web/app/components/billing/annotation-full/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/billing/billing-page/__tests__/index.spec.tsx": { "ts/no-explicit-any": { "count": 4 @@ -1916,11 +1776,6 @@ "count": 1 } }, - "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts": { "erasable-syntax-only/enums": { "count": 1 @@ -1929,9 +1784,6 @@ "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 1 - }, - "no-restricted-imports": { - "count": 1 } }, "web/app/components/datasets/create-from-pipeline/list/template-card/details/types.ts": { @@ -1939,16 +1791,6 @@ "count": 1 } }, - "web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create/file-preview/index.tsx": { "react/set-state-in-effect": { "count": 1 @@ -2003,11 +1845,6 @@ "count": 1 } }, - "web/app/components/datasets/create/stop-embedding-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create/website/firecrawl/index.tsx": { "no-console": { "count": 1 @@ -2066,11 +1903,6 @@ "count": 4 } }, - "web/app/components/datasets/documents/components/rename-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -2141,11 +1973,6 @@ "count": 1 } }, - "web/app/components/datasets/documents/detail/completed/common/regeneration-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/documents/detail/completed/components/segment-list-content.tsx": { "ts/no-non-null-asserted-optional-chain": { "count": 1 @@ -2220,30 +2047,27 @@ "count": 1 } }, + "web/app/components/datasets/formatted-text/flavours/edit-slice.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, + "web/app/components/datasets/formatted-text/flavours/preview-slice.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/datasets/formatted-text/flavours/type.ts": { "ts/no-empty-object-type": { "count": 1 } }, - "web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/datasets/hit-testing/components/result-item-external.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/hit-testing/components/score.tsx": { "unicorn/prefer-number-properties": { "count": 1 } }, "web/app/components/datasets/hit-testing/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/unsupported-syntax": { "count": 1 } @@ -2253,11 +2077,6 @@ "count": 2 } }, - "web/app/components/datasets/metadata/edit-metadata-batch/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/metadata/hooks/use-edit-dataset-metadata.ts": { "react/set-state-in-effect": { "count": 1 @@ -2271,47 +2090,19 @@ "count": 2 } }, - "web/app/components/datasets/metadata/metadata-dataset/create-content.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/create-metadata-modal.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx": { - "no-restricted-imports": { - "count": 2 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/select-metadata-modal.tsx": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "web/app/components/datasets/metadata/types.ts": { "erasable-syntax-only/enums": { "count": 2 } }, - "web/app/components/datasets/rename-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/settings/chunk-structure/types.ts": { "erasable-syntax-only/enums": { "count": 1 } }, "web/app/components/develop/code.tsx": { - "ts/no-empty-object-type": { - "count": 1 - }, "ts/no-explicit-any": { - "count": 9 + "count": 7 } }, "web/app/components/develop/md.tsx": { @@ -2322,16 +2113,6 @@ "count": 2 } }, - "web/app/components/develop/secret-key/secret-key-generate.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/develop/secret-key/secret-key-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/explore/banner/banner-item.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2353,11 +2134,6 @@ "count": 1 } }, - "web/app/components/explore/try-app/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/explore/try-app/tab.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -2437,16 +2213,6 @@ "count": 1 } }, - "web/app/components/header/account-about/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/header/account-setting/api-based-extension-page/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/header/account-setting/data-source-page-new/card.tsx": { "ts/no-explicit-any": { "count": 2 @@ -2516,6 +2282,11 @@ "count": 4 } }, + "web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 6 @@ -2587,9 +2358,6 @@ } }, "web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 }, @@ -2631,9 +2399,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 } @@ -2649,9 +2414,6 @@ } }, "web/app/components/plugins/install-plugin/install-from-github/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -2661,21 +2423,6 @@ "count": 1 } }, - "web/app/components/plugins/install-plugin/install-from-local-package/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/plugins/install-plugin/install-from-local-package/steps/uploading.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, - "web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/marketplace/hooks.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 @@ -2686,6 +2433,11 @@ "count": 1 } }, + "web/app/components/plugins/plugin-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/plugins/plugin-auth/authorized/item.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2761,18 +2513,10 @@ } }, "web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 7 } }, - "web/app/components/plugins/plugin-detail-panel/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/model-list.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2794,9 +2538,6 @@ } }, "web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -2839,15 +2580,22 @@ "count": 7 } }, + "web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-base-form.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 2 } }, - "web/app/components/plugins/plugin-detail-panel/trigger/event-detail-drawer.tsx": { + "web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx": { "no-restricted-imports": { "count": 1 - }, + } + }, + "web/app/components/plugins/plugin-detail-panel/trigger/event-detail-drawer.tsx": { "ts/no-explicit-any": { "count": 5 } @@ -2857,11 +2605,6 @@ "count": 1 } }, - "web/app/components/plugins/plugin-mutation-model/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-page/context.ts": { "ts/no-explicit-any": { "count": 1 @@ -2877,21 +2620,11 @@ "count": 2 } }, - "web/app/components/plugins/plugin-page/plugin-info.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/reference-setting-modal/auto-update-setting/types.ts": { "erasable-syntax-only/enums": { "count": 2 } }, - "web/app/components/plugins/reference-setting-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/types.ts": { "erasable-syntax-only/enums": { "count": 7 @@ -2974,11 +2707,6 @@ "count": 4 } }, - "web/app/components/rag-pipeline/components/publish-as-knowledge-pipeline-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/rag-pipeline/components/rag-pipeline-children.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2994,16 +2722,6 @@ "count": 2 } }, - "web/app/components/rag-pipeline/components/update-dsl-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/rag-pipeline/components/version-mismatch-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/rag-pipeline/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 9 @@ -3059,11 +2777,6 @@ "count": 1 } }, - "web/app/components/share/text-generation/info-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/share/text-generation/menu-dropdown.tsx": { "react/set-state-in-effect": { "count": 1 @@ -3102,20 +2815,7 @@ "count": 2 } }, - "web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/tools/edit-custom-collection-modal/get-schema.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/tools/edit-custom-collection-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -3124,9 +2824,6 @@ } }, "web/app/components/tools/edit-custom-collection-modal/test-api.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -3136,29 +2833,11 @@ "count": 1 } }, - "web/app/components/tools/mcp/detail/provider-detail.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/tools/mcp/mcp-server-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 5 - } - }, "web/app/components/tools/mcp/mcp-server-param-item.tsx": { "ts/no-explicit-any": { "count": 1 } }, - "web/app/components/tools/mcp/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/tools/mcp/provider-card.tsx": { "ts/no-explicit-any": { "count": 3 @@ -3169,20 +2848,12 @@ "count": 1 } }, - "web/app/components/tools/provider/detail.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/tools/provider/empty.tsx": { "ts/no-explicit-any": { "count": 1 } }, "web/app/components/tools/setting/build-in/config-credentials.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -3276,6 +2947,11 @@ "count": 1 } }, + "web/app/components/workflow/block-selector/main.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/workflow/block-selector/market-place-plugin/action.tsx": { "react/set-state-in-effect": { "count": 1 @@ -3291,6 +2967,11 @@ "count": 1 } }, + "web/app/components/workflow/block-selector/tool-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/workflow/block-selector/tool/tool-list-flat-view/list.tsx": { "ts/no-explicit-any": { "count": 1 @@ -3854,16 +3535,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/http/components/authorization/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/http/components/curl-panel.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx": { "ts/no-explicit-any": { "count": 2 @@ -3989,11 +3660,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-item.tsx": { "ts/no-explicit-any": { "count": 1 @@ -4045,11 +3711,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx": { "ts/no-explicit-any": { "count": 3 @@ -4170,9 +3831,6 @@ } }, "web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -4200,30 +3858,9 @@ "count": 9 } }, - "web/app/components/workflow/nodes/question-classifier/components/class-item.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/workflow/nodes/question-classifier/components/class-list.tsx": { "react/set-state-in-effect": { "count": 1 - }, - "react/unsupported-syntax": { - "count": 2 - } - }, - "web/app/components/workflow/nodes/question-classifier/default.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/question-classifier/use-config.ts": { - "react/set-state-in-effect": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 2 } }, "web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts": { @@ -4406,6 +4043,9 @@ } }, "web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react/set-state-in-effect": { "count": 1 } @@ -4416,6 +4056,9 @@ } }, "web/app/components/workflow/operator/add-block.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -4477,9 +4120,6 @@ } }, "web/app/components/workflow/panel/debug-and-preview/conversation-variable-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -4507,16 +4147,6 @@ "count": 4 } }, - "web/app/components/workflow/panel/version-history-panel/delete-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/panel/version-history-panel/restore-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/panel/workflow-preview.tsx": { "ts/no-explicit-any": { "count": 2 @@ -4652,9 +4282,6 @@ } }, "web/app/components/workflow/update-dsl-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -4768,11 +4395,6 @@ "count": 1 } }, - "web/app/education-apply/expire-notice-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/education-apply/hooks.ts": { "react/set-state-in-effect": { "count": 5 diff --git a/packages/dify-ui/AGENTS.md b/packages/dify-ui/AGENTS.md index bdc2160702..9524394214 100644 --- a/packages/dify-ui/AGENTS.md +++ b/packages/dify-ui/AGENTS.md @@ -75,7 +75,7 @@ Composition rules: - Keep Base UI primitive semantics visible in the public API. Export compound parts such as `ComboboxInputGroup`, `ComboboxInput`, `ComboboxContent`, `ComboboxList`, `ComboboxItem`, and `ComboboxItemIndicator` instead of wrapping them into one business component. - For `Combobox` multiple selection, follow the official chips pattern: `ComboboxInputGroup` contains `ComboboxChips`, `ComboboxValue` renders `ComboboxChip` items, and `ComboboxInput` remains inside the chips row. Chips should wrap and let the input group grow vertically instead of forcing horizontal overflow. -- Content primitives must own their Base UI `Portal` and use `z-1002` on `Positioner`, matching the overlay contract in `README.md`. +- Content primitives must own their Base UI `Portal` and use `z-50` on `Positioner`, matching the overlay contract in `README.md`. Toast owns `z-60`. - Use `w-(--anchor-width)` with viewport-aware max-width for `Autocomplete` and `Combobox` popups. Do not add `min-w-(--anchor-width)` when it would defeat available-width clamping. [Autocomplete docs]: https://base-ui.com/react/components/autocomplete.md#usage-guidelines diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index c78faede89..010fb3e56d 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -84,18 +84,18 @@ Equivalent: any root element with `isolation: isolate` in CSS. Without it, overl Every overlay primitive uses a single, shared z-index. Do **not** override it at call sites. -| Layer | z-index | Where | -| ------------------------------------------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | -| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Drawer, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | -| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | +| Layer | z-index | Where | +| ------------------------------------------------------------------------------------------------------------------- | ------- | -------------------------------------------------------------------------- | +| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Drawer, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-50` | Positioner / Backdrop | +| Toast viewport | `z-60` | One layer above overlays so notifications are never hidden under a dialog. | -Rationale: during Dify's migration from legacy `base/modal` / `base/dialog` / `base/drawer` / `base/drawer-plus` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins. +Rationale: Dify UI owns the normal application overlay layer. Overlay primitives share `z-50` and **rely on DOM order** for stacking — the portal mounted later wins. Toast owns `z-60` so notifications remain visible above dialogs, popovers, and other portalled surfaces without falling back to `z-9999`. -See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`. +See `[web/docs/overlay.md](../../web/docs/overlay.md)` for the web app overlay best practices. ### Rules -- Never add `z-1003` / `z-9999` / etc. overrides on primitives from this package. If something is getting clipped, the **parent** overlay (typically a legacy one) is the problem and should be migrated. +- Never add ad hoc `z-*` overrides on primitives from this package. If something is getting clipped, fix the parent overlay structure instead of raising the child primitive. - Never create an extra manual portal on top of our primitives — use the exported content / portal parts such as `DialogContent`, `PopoverContent`, and `DrawerPortal`. Base UI handles focus management, scroll-locking, and dismissal. - When a primitive needs additional presentation chrome (e.g. a custom backdrop), add it **inside** the exported component, not at call sites. diff --git a/packages/dify-ui/package.json b/packages/dify-ui/package.json index 894e92bfd6..96c512f89c 100644 --- a/packages/dify-ui/package.json +++ b/packages/dify-ui/package.json @@ -77,6 +77,14 @@ "types": "./src/switch/index.tsx", "import": "./src/switch/index.tsx" }, + "./tabs": { + "types": "./src/tabs/index.tsx", + "import": "./src/tabs/index.tsx" + }, + "./toggle-group": { + "types": "./src/toggle-group/index.tsx", + "import": "./src/toggle-group/index.tsx" + }, "./toast": { "types": "./src/toast/index.tsx", "import": "./src/toast/index.tsx" diff --git a/packages/dify-ui/src/alert-dialog/index.tsx b/packages/dify-ui/src/alert-dialog/index.tsx index 7b432c87dc..81299ef932 100644 --- a/packages/dify-ui/src/alert-dialog/index.tsx +++ b/packages/dify-ui/src/alert-dialog/index.tsx @@ -29,14 +29,14 @@ export function AlertDialogContent({ { await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'bottom') await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-align', 'start') - await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-50') await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('rounded-xl') await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('w-(--anchor-width)') await expect.element(screen.getByRole('listbox', { name: 'autocomplete list' })).toHaveClass('scroll-py-1') diff --git a/packages/dify-ui/src/autocomplete/index.tsx b/packages/dify-ui/src/autocomplete/index.tsx index 16c4b19673..4c8893b376 100644 --- a/packages/dify-ui/src/autocomplete/index.tsx +++ b/packages/dify-ui/src/autocomplete/index.tsx @@ -261,7 +261,7 @@ export function AutocompleteContent({ align={align} sideOffset={sideOffset} alignOffset={alignOffset} - className={cn('z-1002 outline-hidden', className)} + className={cn('z-50 outline-hidden', className)} {...positionerProps} > { await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'bottom') await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-align', 'start') - await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-50') await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('rounded-xl') await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('w-(--anchor-width)') await expect.element(screen.getByRole('listbox', { name: 'combobox list' })).toHaveClass('scroll-py-1') diff --git a/packages/dify-ui/src/combobox/index.tsx b/packages/dify-ui/src/combobox/index.tsx index c4f03241f6..eb43b911c7 100644 --- a/packages/dify-ui/src/combobox/index.tsx +++ b/packages/dify-ui/src/combobox/index.tsx @@ -323,7 +323,7 @@ export function ComboboxContent({ align={align} sideOffset={sideOffset} alignOffset={alignOffset} - className={cn('z-1002 outline-hidden', className)} + className={cn('z-50 outline-hidden', className)} {...positionerProps} > { expect(screen.container).not.toContainElement(dialog) await expect.element(dialog).toHaveTextContent('Workspace controls') await expect.element(screen.getByText('Configure the current workspace.')).toBeInTheDocument() - await expect.element(screen.getByTestId('drawer-backdrop')).toHaveClass('z-1002') + await expect.element(screen.getByTestId('drawer-backdrop')).toHaveClass('z-50') asHTMLElement(screen.getByRole('button', { name: 'Close drawer' }).element()).click() diff --git a/packages/dify-ui/src/drawer/index.tsx b/packages/dify-ui/src/drawer/index.tsx index c63bc8174e..a2ad6dcdaf 100644 --- a/packages/dify-ui/src/drawer/index.tsx +++ b/packages/dify-ui/src/drawer/index.tsx @@ -32,7 +32,7 @@ export function DrawerBackdrop({ return ( ) @@ -60,7 +60,7 @@ export function DrawerPopup({ return ( element as HTMLElement + +describe('Tabs wrappers', () => { + it('renders Base UI tabs with accessible roles', async () => { + const screen = await render( + + + JavaScript + Python + + JS panel + Python panel + , + ) + + await expect.element(screen.getByRole('tablist')).toBeInTheDocument() + await expect.element(screen.getByRole('tab', { name: 'JavaScript' })).toHaveAttribute('aria-selected', 'true') + await expect.element(screen.getByRole('tab', { name: 'Python' })).toHaveAttribute('aria-selected', 'false') + await expect.element(screen.getByText('JS panel')).toBeInTheDocument() + }) + + it('keeps tabs styling minimal by default', async () => { + const screen = await render( + + + First + Second + + , + ) + + await expect.element(screen.getByRole('tablist')).toHaveClass( + 'flex', + ) + await expect.element(screen.getByRole('tab', { name: 'First' })).toHaveClass( + 'touch-manipulation', + 'focus-visible:outline-hidden', + ) + }) + + it('calls onValueChange while leaving controlled value to the caller', async () => { + const onValueChange = vi.fn() + const screen = await render( + + + JavaScript + Python + + , + ) + + asHTMLElement(screen.getByRole('tab', { name: 'Python' }).element()).click() + + expect(onValueChange).toHaveBeenCalledWith('py', expect.anything()) + await expect.element(screen.getByRole('tab', { name: 'JavaScript' })).toHaveAttribute('aria-selected', 'true') + }) + + it('forwards className to composable parts', async () => { + const screen = await render( + + + First + + Panel + , + ) + + await expect.element(screen.getByRole('tablist')).toHaveClass('custom-list') + await expect.element(screen.getByRole('tab', { name: 'First' })).toHaveClass('custom-tab') + expect(screen.getByText('Panel').element()).toHaveClass('custom-panel') + }) +}) diff --git a/packages/dify-ui/src/tabs/index.stories.tsx b/packages/dify-ui/src/tabs/index.stories.tsx new file mode 100644 index 0000000000..dd1e79a1ce --- /dev/null +++ b/packages/dify-ui/src/tabs/index.stories.tsx @@ -0,0 +1,51 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import { + Tabs, + TabsList, + TabsPanel, + TabsTab, +} from '.' + +const meta = { + title: 'Base/UI/Tabs', + component: Tabs, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Composable tabs built on Base UI. Use this when a tab controls a corresponding tab panel.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +export const Basic: Story = { + render: () => ( + + + + Overview + + + Activity + + + + Overview panel + + + Activity panel + + + ), +} diff --git a/packages/dify-ui/src/tabs/index.tsx b/packages/dify-ui/src/tabs/index.tsx new file mode 100644 index 0000000000..ddc5891b89 --- /dev/null +++ b/packages/dify-ui/src/tabs/index.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { Tabs as BaseTabsNS } from '@base-ui/react/tabs' +import { Tabs as BaseTabs } from '@base-ui/react/tabs' +import { cn } from '../cn' + +export type TabsProps = BaseTabsNS.Root.Props + +export const Tabs = BaseTabs.Root + +export type TabsListProps = Omit & { + className?: string +} + +export function TabsList({ + className, + ...props +}: TabsListProps) { + return ( + + ) +} + +export type TabsTabProps = Omit & { + className?: string +} + +export function TabsTab({ + className, + ...props +}: TabsTabProps) { + return ( + + ) +} + +export type TabsPanelProps = Omit & { + className?: string +} + +export function TabsPanel({ + className, + ...props +}: TabsPanelProps) { + return ( + + ) +} + +export const TabsIndicator = BaseTabs.Indicator diff --git a/packages/dify-ui/src/toast/__tests__/index.spec.tsx b/packages/dify-ui/src/toast/__tests__/index.spec.tsx index e02f6828ac..68ba746f4f 100644 --- a/packages/dify-ui/src/toast/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/toast/__tests__/index.spec.tsx @@ -39,7 +39,7 @@ describe('@langgenius/dify-ui/toast', () => { await expect.element(screen.getByText('Saved')).toBeInTheDocument() await expect.element(screen.getByText('Your changes are available now.')).toBeInTheDocument() await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveAttribute('aria-live', 'polite') - await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveClass('z-1003') + await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveClass('z-60') expect(screen.getByRole('region', { name: 'Notifications' }).element().firstElementChild).toHaveClass('top-4') expect(screen.getByRole('dialog').element()).not.toHaveClass('outline-hidden') expect(document.body.querySelector('[aria-hidden="true"].i-ri-checkbox-circle-fill')).toBeInTheDocument() diff --git a/packages/dify-ui/src/toast/index.tsx b/packages/dify-ui/src/toast/index.tsx index a479621563..7d4e867faf 100644 --- a/packages/dify-ui/src/toast/index.tsx +++ b/packages/dify-ui/src/toast/index.tsx @@ -222,7 +222,7 @@ function ToastViewport() {
element as HTMLElement + +describe('ToggleGroup wrappers', () => { + it('renders a segmented control with Base UI pressed state', async () => { + const screen = await render( + + One + Two + , + ) + + await expect.element(screen.getByRole('group')).toHaveClass( + 'bg-components-segmented-control-bg-normal', + 'p-0.5', + 'rounded-[10px]', + ) + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'true') + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveClass( + 'data-pressed:bg-components-segmented-control-item-active-bg', + 'data-pressed:text-text-accent-light-mode-only', + ) + }) + + it('uses single selection by default', async () => { + const screen = await render( + + One + Two + , + ) + + asHTMLElement(screen.getByRole('button', { name: 'Two' }).element()).click() + + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'false') + await expect.element(screen.getByRole('button', { name: 'Two' })).toHaveAttribute('aria-pressed', 'true') + }) + + it('calls onValueChange while leaving controlled value to the caller', async () => { + const onValueChange = vi.fn() + const screen = await render( + + One + Two + , + ) + + asHTMLElement(screen.getByRole('button', { name: 'Two' }).element()).click() + + expect(onValueChange).toHaveBeenCalledWith(['two'], expect.anything()) + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'true') + }) + + it('forwards disabled and className to composable parts', async () => { + const screen = await render( + + One + + Two + , + ) + + await expect.element(screen.getByRole('group')).toHaveClass('custom-group') + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveClass('custom-item') + await expect.element(screen.getByRole('button', { name: 'Two' })).toBeDisabled() + await expect.element(screen.getByTestId('divider')).toHaveClass('custom-divider') + }) +}) diff --git a/packages/dify-ui/src/toggle-group/index.stories.tsx b/packages/dify-ui/src/toggle-group/index.stories.tsx new file mode 100644 index 0000000000..960957b7ab --- /dev/null +++ b/packages/dify-ui/src/toggle-group/index.stories.tsx @@ -0,0 +1,177 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { ReactNode } from 'react' +import { + ToggleGroup, + ToggleGroupDivider, + ToggleGroupItem, +} from '.' + +const meta = { + title: 'Base/UI/ToggleGroup', + component: ToggleGroup, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Segmented control built on Base UI ToggleGroup and Toggle. Use this for mode, filter, and view selection that does not need tabpanel semantics.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +type SegmentedControlProps = { + defaultValue: string + values: string[] + iconOnly?: boolean + noPadding?: boolean +} + +const Icon = () => ( +