From 512417cbd65a12c5aac4fe5cac4e4b2aca37f85e Mon Sep 17 00:00:00 2001 From: chariri Date: Fri, 26 Jun 2026 03:17:53 +0900 Subject: [PATCH 1/2] refactor(api): migrate dataset rag pipeline endpoints to BaseModel --- .../datasets/rag_pipeline/datasource_auth.py | 130 ++++++--- .../datasource_content_preview.py | 11 +- .../rag_pipeline/rag_pipeline_workflow.py | 63 ++-- .../dataset/rag_pipeline/serializers.py | 32 --- api/openapi/markdown/console-openapi.md | 137 +++++++-- api/openapi/markdown/service-openapi.md | 28 +- .../test_rag_pipeline_workflow.py | 10 +- .../rag_pipeline/test_datasource_auth.py | 271 +++++++++++++++--- .../test_rag_pipeline_workflow.py | 62 ++++ .../test_rag_pipeline_workflow.py | 47 ++- .../generated/api/console/auth/types.gen.ts | 91 +++++- .../generated/api/console/auth/zod.gen.ts | 147 ++++++++-- .../generated/api/console/oauth/zod.gen.ts | 2 +- .../generated/api/console/rag/orpc.gen.ts | 40 +-- .../generated/api/console/rag/types.gen.ts | 28 -- .../generated/api/console/rag/zod.gen.ts | 29 -- .../generated/api/service/types.gen.ts | 44 ++- .../generated/api/service/zod.gen.ts | 10 +- packages/contracts/openapi-ts.api.config.ts | 78 ++++- 19 files changed, 926 insertions(+), 334 deletions(-) delete mode 100644 api/controllers/service_api/dataset/rag_pipeline/serializers.py diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a575760ee19..389515bd4f9 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config -from controllers.common.fields import RedirectResponse, SimpleResultResponse +from controllers.common.fields import SimpleResultResponse from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( @@ -19,11 +19,13 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from core.entities.provider_entities import ProviderConfig from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.common_entities import I18nObject from fields.base import ResponseModel from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder +from libs.helper import dump_response from libs.login import login_required from models import Account from models.provider_ids import DatasourceProviderID @@ -33,7 +35,9 @@ from services.plugin.oauth_service import OAuthProxyService class DatasourceCredentialPayload(BaseModel): name: str | None = Field(default=None, max_length=100) - credentials: dict[str, Any] + credentials: dict[str, Any] = Field( + description="Plugin-defined credential parameters. The schema is declared by the datasource provider." + ) class DatasourceCredentialDeletePayload(BaseModel): @@ -43,11 +47,17 @@ class DatasourceCredentialDeletePayload(BaseModel): class DatasourceCredentialUpdatePayload(BaseModel): credential_id: str name: str | None = Field(default=None, max_length=100) - credentials: dict[str, Any] | None = Field(default=None) + credentials: dict[str, Any] | None = Field( + default=None, + description="Plugin-defined credential parameters. The schema is declared by the datasource provider.", + ) class DatasourceCustomClientPayload(BaseModel): - client_params: dict[str, Any] | None = Field(default=None) + client_params: dict[str, Any] | None = Field( + default=None, + description="Plugin-defined OAuth client parameters. The schema is declared by the datasource provider.", + ) enable_oauth_custom_client: bool | None = None @@ -71,8 +81,48 @@ class DatasourceOAuthCallbackQuery(BaseModel): context_id: str | None = Field(default=None, description="OAuth proxy context ID") -class DatasourceCredentialsResponse(ResponseModel): - result: Any +class DatasourceCredentialResponse(ResponseModel): + credential: dict[str, Any] = Field( + description="Obfuscated plugin-defined credential parameters from the datasource provider." + ) + type: str + name: str + avatar_url: str | None + id: str + is_default: bool + + +class DatasourceCredentialListResponse(ResponseModel): + result: list[DatasourceCredentialResponse] + + +class DatasourceOAuthSchemaResponse(ResponseModel): + client_schema: list[ProviderConfig] + credentials_schema: list[ProviderConfig] + oauth_custom_client_params: dict[str, Any] | None = Field( + description="Masked plugin-defined OAuth client parameters, when configured for the tenant." + ) + is_oauth_custom_client_enabled: bool + is_system_oauth_params_exists: bool + redirect_uri: str + + +class DatasourceProviderAuthResponse(ResponseModel): + author: str + provider: str + plugin_id: str + plugin_unique_identifier: str + icon: str + name: str + label: I18nObject + description: I18nObject + credential_schema: list[ProviderConfig] + oauth_schema: DatasourceOAuthSchemaResponse | None + credentials_list: list[DatasourceCredentialResponse] + + +class DatasourceProviderAuthListResponse(ResponseModel): + result: list[DatasourceProviderAuthResponse] register_schema_models( @@ -88,9 +138,9 @@ register_schema_models( ) register_response_schema_models( console_ns, - DatasourceCredentialsResponse, + DatasourceCredentialListResponse, + DatasourceProviderAuthListResponse, PluginOAuthAuthorizationUrlResponse, - RedirectResponse, SimpleResultResponse, ) @@ -100,7 +150,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @console_ns.doc(params=query_params_from_model(DatasourceOAuthAuthorizationQuery)) @console_ns.response( 200, - "Authorization URL retrieved successfully", + "Datasource OAuth authorization URL generated successfully", console_ns.models[PluginOAuthAuthorizationUrlResponse.__name__], ) @setup_required @@ -140,7 +190,8 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): redirect_uri=redirect_uri, system_credentials=oauth_config, ) - response = make_response(jsonable_encoder(authorization_url_response)) + # response-contract:ignore cookie-bearing Flask response + response = make_response(dump_response(PluginOAuthAuthorizationUrlResponse, authorization_url_response)) response.set_cookie( "context_id", context_id, @@ -154,11 +205,8 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @console_ns.route("/oauth/plugin//datasource/callback") class DatasourceOAuthCallback(Resource): @console_ns.doc(params=query_params_from_model(DatasourceOAuthCallbackQuery)) - @console_ns.response( - 302, - "Redirect to console OAuth callback page", - console_ns.models[RedirectResponse.__name__], - ) + # response-contract:ignore redirect response + @console_ns.response(302, "Redirect to OAuth callback page") @setup_required def get(self, provider_id: str): context_id = request.cookies.get("context_id") or request.args.get("context_id") @@ -217,7 +265,9 @@ class DatasourceOAuthCallback(Resource): @console_ns.route("/auth/plugin/datasource/") class DatasourceAuth(Resource): @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__]) - @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response( + 200, "Datasource credential created successfully", console_ns.models[SimpleResultResponse.__name__] + ) @setup_required @login_required @account_initialization_required @@ -238,12 +288,16 @@ class DatasourceAuth(Resource): ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 + @console_ns.response( + 200, + "Datasource credentials retrieved successfully", + console_ns.models[DatasourceCredentialListResponse.__name__], + ) @setup_required @login_required @account_initialization_required - @console_ns.response(200, "Success", console_ns.models[DatasourceCredentialsResponse.__name__]) @with_current_user @with_current_tenant_id def get(self, current_tenant_id: str, user: Account, provider_id: str): @@ -256,7 +310,7 @@ class DatasourceAuth(Resource): plugin_id=datasource_provider_id.plugin_id, user=user, ) - return {"result": datasources}, 200 + return dump_response(DatasourceCredentialListResponse, {"result": datasources}), 200 @console_ns.route("/auth/plugin/datasource//delete") @@ -282,13 +336,15 @@ class DatasourceAuthDeleteApi(Resource): provider=provider_name, plugin_id=plugin_id, ) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 @console_ns.route("/auth/plugin/datasource//update") class DatasourceAuthUpdateApi(Resource): @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__]) - @console_ns.response(201, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response( + 201, "Datasource credential updated successfully", console_ns.models[SimpleResultResponse.__name__] + ) @setup_required @login_required @account_initialization_required @@ -308,12 +364,16 @@ class DatasourceAuthUpdateApi(Resource): credentials=payload.credentials or {}, name=payload.name, ) - return {"result": "success"}, 201 + return SimpleResultResponse(result="success").model_dump(mode="json"), 201 @console_ns.route("/auth/plugin/datasource/list") class DatasourceAuthListApi(Resource): - @console_ns.response(200, "Success", console_ns.models[DatasourceCredentialsResponse.__name__]) + @console_ns.response( + 200, + "Datasource credentials retrieved successfully", + console_ns.models[DatasourceProviderAuthListResponse.__name__], + ) @setup_required @login_required @account_initialization_required @@ -321,12 +381,16 @@ class DatasourceAuthListApi(Resource): def get(self, current_tenant_id: str): datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id) - return {"result": jsonable_encoder(datasources)}, 200 + return dump_response(DatasourceProviderAuthListResponse, {"result": datasources}), 200 @console_ns.route("/auth/plugin/datasource/default-list") class DatasourceHardCodeAuthListApi(Resource): - @console_ns.response(200, "Success", console_ns.models[DatasourceCredentialsResponse.__name__]) + @console_ns.response( + 200, + "Default datasource credentials retrieved successfully", + console_ns.models[DatasourceProviderAuthListResponse.__name__], + ) @setup_required @login_required @account_initialization_required @@ -334,13 +398,15 @@ class DatasourceHardCodeAuthListApi(Resource): def get(self, current_tenant_id: str): datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id) - return {"result": jsonable_encoder(datasources)}, 200 + return dump_response(DatasourceProviderAuthListResponse, {"result": datasources}), 200 @console_ns.route("/auth/plugin/datasource//custom-client") class DatasourceAuthOauthCustomClient(Resource): @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__]) - @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response( + 200, "Datasource OAuth custom client saved successfully", console_ns.models[SimpleResultResponse.__name__] + ) @setup_required @login_required @account_initialization_required @@ -357,7 +423,7 @@ class DatasourceAuthOauthCustomClient(Resource): client_params=payload.client_params or {}, enabled=payload.enable_oauth_custom_client or False, ) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 @setup_required @login_required @@ -371,7 +437,7 @@ class DatasourceAuthOauthCustomClient(Resource): tenant_id=current_tenant_id, datasource_provider_id=datasource_provider_id, ) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 @console_ns.route("/auth/plugin/datasource//default") @@ -393,7 +459,7 @@ class DatasourceAuthDefaultApi(Resource): datasource_provider_id=datasource_provider_id, credential_id=payload.id, ) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 @console_ns.route("/auth/plugin/datasource//update-name") @@ -416,4 +482,4 @@ class DatasourceUpdateProviderNameApi(Resource): name=payload.name, credential_id=payload.credential_id, ) - return {"result": "success"}, 200 + return SimpleResultResponse(result="success").model_dump(mode="json"), 200 diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index b0af108444c..213337fedc9 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -3,9 +3,9 @@ from typing import Any from flask_restx import ( # type: ignore Resource, # type: ignore ) -from pydantic import BaseModel, RootModel +from pydantic import BaseModel -from controllers.common.schema import register_response_schema_models, register_schema_models +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required, with_current_user @@ -21,18 +21,13 @@ class Parser(BaseModel): credential_id: str | None = None -class DataSourceContentPreviewResponse(RootModel[Any]): - root: Any - - register_schema_models(console_ns, Parser) -register_response_schema_models(console_ns, DataSourceContentPreviewResponse) @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") class DataSourceContentPreviewApi(Resource): @console_ns.expect(console_ns.models[Parser.__name__]) - @console_ns.response(200, "Success", console_ns.models[DataSourceContentPreviewResponse.__name__]) + @console_ns.response(200, "Success") @setup_required @login_required @account_initialization_required diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index a6a61262cdc..0b28f4d10c8 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -1,34 +1,30 @@ from collections.abc import Generator +from datetime import datetime from typing import Any from uuid import UUID from flask import request -from pydantic import BaseModel, Field, RootModel +from pydantic import BaseModel, Field, RootModel, field_validator from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError -from controllers.common.fields import GeneratedAppResponse from controllers.common.schema import ( query_params_from_model, + query_params_from_request, register_response_schema_models, register_schema_model, - register_schema_models, ) from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import PipelineRunError -from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file -from controllers.service_api.schema import ( - event_stream_response, - json_or_event_stream_response, - multipart_file_params, -) +from controllers.service_api.schema import event_stream_response, json_or_event_stream_response, multipart_file_params from controllers.service_api.wraps import DatasetApiResource from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from fields.base import ResponseModel from libs import helper +from libs.helper import dump_response from libs.login import current_user from models import Account from models.dataset import Dataset, Pipeline @@ -82,7 +78,7 @@ class DatasourcePluginResponse(ResponseModel): datasource_type: str | None = None title: str | None = None user_input_variables: list[dict[str, Any]] = Field(default_factory=list) - credentials: list[DatasourceCredentialInfoResponse] + credentials: list[DatasourceCredentialInfoResponse] = Field(default_factory=list) class DatasourcePluginListResponse(RootModel[list[DatasourcePluginResponse]]): @@ -98,14 +94,22 @@ class PipelineUploadFileResponse(ResponseModel): created_by: str created_at: str | None = None + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | str | None) -> str | None: + if isinstance(value, datetime): + return value.isoformat() + return value + register_schema_model(service_api_ns, DatasourceNodeRunPayload) +register_schema_model(service_api_ns, DatasourcePluginsQuery) register_schema_model(service_api_ns, PipelineRunApiEntity) -register_schema_models(service_api_ns, DatasourcePluginsQuery) register_response_schema_models( service_api_ns, + DatasourceCredentialInfoResponse, + DatasourcePluginResponse, DatasourcePluginListResponse, - GeneratedAppResponse, PipelineUploadFileResponse, ) @@ -117,8 +121,8 @@ class DatasourcePluginsApi(DatasetApiResource): @service_api_ns.doc( summary="List Datasource Plugins", description=( - "List the datasource nodes configured in the knowledge pipeline. Each node includes the " - "plugin it uses plus the metadata needed to run it." + "List the datasource nodes configured in the knowledge pipeline. Each node includes the plugin it uses " + "plus the metadata needed to run it." ), tags=["Knowledge Pipeline"], responses={ @@ -150,14 +154,13 @@ class DatasourcePluginsApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - # Get query parameter to determine published or draft - is_published: bool = request.args.get("is_published", default=True, type=bool) + query = query_params_from_request(DatasourcePluginsQuery) rag_pipeline_service: RagPipelineService = RagPipelineService() datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins( - tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published + tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=query.is_published ) - return datasource_plugins, 200 + return dump_response(DatasourcePluginListResponse, datasource_plugins), 200 @service_api_ns.route("/datasets//pipeline/datasource/nodes//run") @@ -167,8 +170,8 @@ class DatasourceNodeRunApi(DatasetApiResource): @service_api_ns.doc( summary="Run Datasource Node", description=( - "Execute a single datasource node within the knowledge pipeline. Returns a streaming " - "response with the node execution results." + "Execute a single datasource node within the knowledge pipeline. Returns a streaming response with the " + "node execution results." ), tags=["Knowledge Pipeline"], responses={ @@ -187,11 +190,6 @@ class DatasourceNodeRunApi(DatasetApiResource): } ) @service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__]) - @service_api_ns.response( - 200, - "Datasource node run successfully", - service_api_ns.models[GeneratedAppResponse.__name__], - ) def post(self, tenant_id: str, dataset_id: UUID, node_id: str): """Resource for getting datasource plugins.""" dataset_id_str = str(dataset_id) @@ -208,10 +206,11 @@ class DatasourceNodeRunApi(DatasetApiResource): datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate( { **payload.model_dump(exclude_none=True), - "pipeline_id": str(pipeline.id), + "pipeline_id": pipeline.id, "node_id": node_id, } ) + # response-contract:ignore compact_generate_response return helper.compact_generate_response( PipelineGenerator.convert_to_event_stream( rag_pipeline_service.run_datasource_workflow_node( @@ -234,8 +233,8 @@ class PipelineRunApi(DatasetApiResource): @service_api_ns.doc( summary="Run Pipeline", description=( - "Execute the full knowledge pipeline for a knowledge base. Supports both streaming and " - "blocking response modes." + "Execute the full knowledge pipeline for a knowledge base. Supports both streaming and blocking response " + "modes." ), tags=["Knowledge Pipeline"], responses={ @@ -259,11 +258,6 @@ class PipelineRunApi(DatasetApiResource): } ) @service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__]) - @service_api_ns.response( - 200, - "Pipeline run successfully", - service_api_ns.models[GeneratedAppResponse.__name__], - ) def post(self, tenant_id: str, dataset_id: UUID): """Resource for running a rag pipeline.""" dataset_id_str = str(dataset_id) @@ -289,6 +283,7 @@ class PipelineRunApi(DatasetApiResource): streaming=payload.response_mode == "streaming", ) + # response-contract:ignore compact_generate_response return helper.compact_generate_response(response) except Exception as ex: raise PipelineRunError(description=str(ex)) @@ -364,4 +359,4 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return serialize_upload_file(upload_file), 201 + return dump_response(PipelineUploadFileResponse, upload_file), 201 diff --git a/api/controllers/service_api/dataset/rag_pipeline/serializers.py b/api/controllers/service_api/dataset/rag_pipeline/serializers.py deleted file mode 100644 index a5e8484037e..00000000000 --- a/api/controllers/service_api/dataset/rag_pipeline/serializers.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Serialization helpers for Service API knowledge pipeline endpoints. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, TypedDict - -if TYPE_CHECKING: - from models.model import UploadFile - - -class UploadFileDict(TypedDict): - id: str - name: str - size: int - extension: str - mime_type: str | None - created_by: str - created_at: str | None - - -def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict: - return { - "id": upload_file.id, - "name": upload_file.name, - "size": upload_file.size, - "extension": upload_file.extension, - "mime_type": upload_file.mime_type, - "created_by": upload_file.created_by, - "created_at": upload_file.created_at.isoformat() if upload_file.created_at else None, - } diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index b3a0b8a6a71..a683e2554b1 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -4496,14 +4496,14 @@ Refresh MCP server configuration and regenerate server code | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [DatasourceCredentialsResponse](#datasourcecredentialsresponse)
| +| 200 | Default datasource credentials retrieved successfully | **application/json**: [DatasourceProviderAuthListResponse](#datasourceproviderauthlistresponse)
| ### [GET] /auth/plugin/datasource/list #### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [DatasourceCredentialsResponse](#datasourcecredentialsresponse)
| +| 200 | Datasource credentials retrieved successfully | **application/json**: [DatasourceProviderAuthListResponse](#datasourceproviderauthlistresponse)
| ### [GET] /auth/plugin/datasource/{provider_id} #### Parameters @@ -4516,7 +4516,7 @@ Refresh MCP server configuration and regenerate server code | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [DatasourceCredentialsResponse](#datasourcecredentialsresponse)
| +| 200 | Datasource credentials retrieved successfully | **application/json**: [DatasourceCredentialListResponse](#datasourcecredentiallistresponse)
| ### [POST] /auth/plugin/datasource/{provider_id} #### Parameters @@ -4535,7 +4535,7 @@ Refresh MCP server configuration and regenerate server code | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| +| 200 | Datasource credential created successfully | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| ### [DELETE] /auth/plugin/datasource/{provider_id}/custom-client #### Parameters @@ -4567,7 +4567,7 @@ Refresh MCP server configuration and regenerate server code | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| +| 200 | Datasource OAuth custom client saved successfully | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| ### [POST] /auth/plugin/datasource/{provider_id}/default #### Parameters @@ -4624,7 +4624,7 @@ Refresh MCP server configuration and regenerate server code | Code | Description | Schema | | ---- | ----------- | ------ | -| 201 | Success | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| +| 201 | Datasource credential updated successfully | **application/json**: [SimpleResultResponse](#simpleresultresponse)
| ### [POST] /auth/plugin/datasource/{provider_id}/update-name #### Parameters @@ -7022,9 +7022,9 @@ Initiate OAuth login process #### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 302 | Redirect to console OAuth callback page | **application/json**: [RedirectResponse](#redirectresponse)
| +| Code | Description | +| ---- | ----------- | +| 302 | Redirect to OAuth callback page | ### [GET] /oauth/plugin/{provider_id}/datasource/get-authorization-url #### Parameters @@ -7038,7 +7038,7 @@ Initiate OAuth login process | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Authorization URL retrieved successfully | **application/json**: [PluginOAuthAuthorizationUrlResponse](#pluginoauthauthorizationurlresponse)
| +| 200 | Datasource OAuth authorization URL generated successfully | **application/json**: [PluginOAuthAuthorizationUrlResponse](#pluginoauthauthorizationurlresponse)
| ### [GET] /oauth/plugin/{provider}/tool/authorization-url #### Parameters @@ -7861,9 +7861,9 @@ Initiate OAuth login process #### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 200 | Success | **application/json**: [DataSourceContentPreviewResponse](#datasourcecontentpreviewresponse)
| +| Code | Description | +| ---- | ----------- | +| 200 | Success | ### [POST] /rag/pipelines/{pipeline_id}/workflows/published/datasource/nodes/{node_id}/run **Run rag pipeline datasource** @@ -13909,6 +13909,12 @@ AppMCPServer Status Enum | use_icon_as_answer_icon | boolean | | No | | workflow | [WorkflowPartial](#workflowpartial) | | No | +#### AppSelectorScope + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| AppSelectorScope | string | | | + #### AppSiteResponse | Name | Type | Description | Required | @@ -14851,12 +14857,6 @@ Model class for provider custom model configuration. | ---- | ---- | ----------- | -------- | | info_list | [InfoList](#infolist) | | Yes | -#### DataSourceContentPreviewResponse - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| DataSourceContentPreviewResponse | | | | - #### DataSourceIntegrateIconResponse | Name | Type | Description | Required | @@ -15377,32 +15377,43 @@ Model class for provider custom model configuration. | ---- | ---- | ----------- | -------- | | credential_id | string | | Yes | +#### DatasourceCredentialListResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| result | [ [DatasourceCredentialResponse](#datasourcecredentialresponse) ] | | Yes | + #### DatasourceCredentialPayload | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| credentials | object | | Yes | +| credentials | object | Plugin-defined credential parameters. The schema is declared by the datasource provider. | Yes | | name | string | | No | +#### DatasourceCredentialResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| avatar_url | string | | Yes | +| credential | object | Obfuscated plugin-defined credential parameters from the datasource provider. | Yes | +| id | string | | Yes | +| is_default | boolean | | Yes | +| name | string | | Yes | +| type | string | | Yes | + #### DatasourceCredentialUpdatePayload | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | credential_id | string | | Yes | -| credentials | object | | No | +| credentials | object | Plugin-defined credential parameters. The schema is declared by the datasource provider. | No | | name | string | | No | -#### DatasourceCredentialsResponse - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| result | | | Yes | - #### DatasourceCustomClientPayload | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| client_params | object | | No | +| client_params | object | Plugin-defined OAuth client parameters. The schema is declared by the datasource provider. | No | | enable_oauth_custom_client | boolean | | No | #### DatasourceDefaultPayload @@ -15434,6 +15445,39 @@ Model class for provider custom model configuration. | error | string | Error message from OAuth provider | No | | state | string | OAuth state parameter | No | +#### DatasourceOAuthSchemaResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| client_schema | [ [ProviderConfig](#providerconfig) ] | | Yes | +| credentials_schema | [ [ProviderConfig](#providerconfig) ] | | Yes | +| is_oauth_custom_client_enabled | boolean | | Yes | +| is_system_oauth_params_exists | boolean | | Yes | +| oauth_custom_client_params | object | Masked plugin-defined OAuth client parameters, when configured for the tenant. | Yes | +| redirect_uri | string | | Yes | + +#### DatasourceProviderAuthListResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| result | [ [DatasourceProviderAuthResponse](#datasourceproviderauthresponse) ] | | Yes | + +#### DatasourceProviderAuthResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| author | string | | Yes | +| credential_schema | [ [ProviderConfig](#providerconfig) ] | | Yes | +| credentials_list | [ [DatasourceCredentialResponse](#datasourcecredentialresponse) ] | | Yes | +| description | [I18nObject](#i18nobject) | | Yes | +| icon | string | | Yes | +| label | [I18nObject](#i18nobject) | | Yes | +| name | string | | Yes | +| oauth_schema | [DatasourceOAuthSchemaResponse](#datasourceoauthschemaresponse) | | Yes | +| plugin_id | string | | Yes | +| plugin_unique_identifier | string | | Yes | +| provider | string | | Yes | + #### DatasourceUpdateNamePayload | Name | Type | Description | Required | @@ -17281,6 +17325,12 @@ Enum class for model property key. | ---- | ---- | ----------- | -------- | | payment_link | string | | Yes | +#### ModelSelectorScope + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| ModelSelectorScope | string | | | + #### ModelStatus Enum class for model status. @@ -17588,6 +17638,13 @@ Coarse node-level status used by Inspector to pick a banner. | ---- | ---- | ----------- | -------- | | OpaqueObjectResponse | object | | | +#### Option + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| label | [I18nObject](#i18nobject) | The label of the option | Yes | +| value | string | The value of the option | Yes | + #### OutputErrorStrategy Per-output failure handling strategy. @@ -18429,6 +18486,24 @@ Dataset Process Rule Mode | ---- | ---- | ----------- | -------- | | ProcessRuleMode | string | Dataset Process Rule Mode | | +#### ProviderConfig + +Model class for common provider settings like credentials + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| default | integer
string
number
boolean | | No | +| help | [I18nObject](#i18nobject) | | No | +| label | [I18nObject](#i18nobject) | | No | +| multiple | boolean | | No | +| name | string | The name of the credentials | Yes | +| options | [ [Option](#option) ] | | No | +| placeholder | [I18nObject](#i18nobject) | | No | +| required | boolean | | No | +| scope | [AppSelectorScope](#appselectorscope)
[ModelSelectorScope](#modelselectorscope)
[ToolSelectorScope](#toolselectorscope) | | No | +| type | [Type](#type) | The type of the credentials | Yes | +| url | string | | No | + #### ProviderCredentialResponse | Name | Type | Description | Required | @@ -19841,6 +19916,12 @@ Enum class for tool provider | ---- | ---- | ----------- | -------- | | ToolProviderType | string | Enum class for tool provider | | +#### ToolSelectorScope + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| ToolSelectorScope | string | | | + #### TraceAppConfigResponse | Name | Type | Description | Required | diff --git a/api/openapi/markdown/service-openapi.md b/api/openapi/markdown/service-openapi.md index 8fc5e75e3cf..5d9956d2ff8 100644 --- a/api/openapi/markdown/service-openapi.md +++ b/api/openapi/markdown/service-openapi.md @@ -1046,12 +1046,12 @@ Execute a single datasource node within the knowledge pipeline. Returns a stream #### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 200 | Streaming response with node execution events. | **text/event-stream**: [GeneratedAppResponse](#generatedappresponse)
| -| 401 | Unauthorized - invalid API token | | -| 403 | Forbidden - dataset API access or workspace access denied | | -| 404 | `not_found` : Dataset not found. | | +| Code | Description | +| ---- | ----------- | +| 200 | Streaming response with node execution events. | +| 401 | Unauthorized - invalid API token | +| 403 | Forbidden - dataset API access or workspace access denied | +| 404 | `not_found` : Dataset not found. | ### [POST] /datasets/{dataset_id}/pipeline/run **Run Pipeline** @@ -1072,13 +1072,13 @@ Execute the full knowledge pipeline for a knowledge base. Supports both streamin #### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 200 | Pipeline execution result. Format depends on `response_mode`: streaming returns a `text/event-stream`, blocking returns a JSON object. | **application/json**: [GeneratedAppResponse](#generatedappresponse)
**text/event-stream**: [GeneratedAppResponse](#generatedappresponse)
| -| 401 | Unauthorized - invalid API token | | -| 403 | `forbidden` : Forbidden. | | -| 404 | `not_found` : Dataset not found. | | -| 500 | `pipeline_run_error` : Pipeline execution failed. | | +| Code | Description | +| ---- | ----------- | +| 200 | Pipeline execution result. Format depends on `response_mode`: streaming returns a `text/event-stream`, blocking returns a JSON object. | +| 401 | Unauthorized - invalid API token | +| 403 | `forbidden` : Forbidden. | +| 404 | `not_found` : Dataset not found. | +| 500 | `pipeline_run_error` : Pipeline execution failed. | --- ## default @@ -2960,7 +2960,7 @@ Enum class for custom configuration status. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| credentials | [ [DatasourceCredentialInfoResponse](#datasourcecredentialinforesponse) ] | | Yes | +| credentials | [ [DatasourceCredentialInfoResponse](#datasourcecredentialinforesponse) ] | | No | | datasource_type | string | | No | | node_id | string | | No | | plugin_id | string | | No | diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index bdec903ef33..4b8186f3017 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -607,7 +607,11 @@ class TestMiscApis: method = unwrap(api.get) service = MagicMock() - service.get_recommended_plugins.return_value = [{"id": "p1"}] + recommended_plugins = { + "installed_recommended_plugins": [{"id": "p1"}], + "uninstalled_recommended_plugins": [{"id": "p2"}], + } + service.get_recommended_plugins.return_value = recommended_plugins user = make_account() tenant_id = "tenant-1" @@ -619,7 +623,7 @@ class TestMiscApis: ), ): result = method(api, tenant_id, user) - assert result == [{"id": "p1"}] + assert result == recommended_plugins service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id) @@ -826,7 +830,7 @@ class TestRagPipelineByIdApi: result = method(api, pipeline, "old-workflow") workflow_service.delete_workflow.assert_called_once() - assert result == (None, 204) + assert result == ("", 204) def test_delete_active_workflow_rejected(self, app: Flask) -> None: api = RagPipelineByIdApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index e8faece89ca..e56bfe4adb5 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,4 +1,5 @@ import inspect +from datetime import UTC, datetime from unittest.mock import MagicMock, patch import pytest @@ -23,6 +24,76 @@ from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService +_PROVIDER_ID = "langgenius/notion_datasource/notion" + + +def _i18n(text: str) -> dict[str, str]: + return {"en_US": text, "zh_Hans": text, "pt_BR": text, "ja_JP": text} + + +def _provider_config(name: str, type_: str, label: str, *, required: bool = True) -> dict: + return { + "type": type_, + "name": name, + "scope": None, + "required": required, + "default": None, + "options": None, + "multiple": False, + "label": _i18n(label), + "help": None, + "url": None, + "placeholder": None, + } + + +def _datasource_credential(credential_id: str = "cred-1", *, is_default: bool = True) -> dict: + return { + "credential": { + "api_key": "******", + "workspace": "engineering", + "database_id": "db-123", + }, + "type": "api-key", + "name": "API Key", + "avatar_url": "https://cdn.example.com/notion.png", + "id": credential_id, + "is_default": is_default, + } + + +def _datasource_auth() -> dict: + return { + "author": "Dify", + "provider": "notion", + "plugin_id": "langgenius/notion_datasource", + "plugin_unique_identifier": "langgenius/notion_datasource:0.0.1", + "icon": "icon.svg", + "name": "notion", + "label": _i18n("Notion"), + "description": _i18n("Notion datasource"), + "credential_schema": [ + _provider_config("api_key", "secret-input", "API key"), + ], + "oauth_schema": { + "client_schema": [ + _provider_config("client_id", "text-input", "Client ID"), + ], + "credentials_schema": [ + _provider_config("access_token", "secret-input", "Access token"), + ], + "oauth_custom_client_params": {"client_id": "masked-client", "client_secret": "********"}, + "is_oauth_custom_client_enabled": True, + "is_system_oauth_params_exists": True, + "redirect_uri": "https://api.example.com/oauth/callback", + }, + "credentials_list": [_datasource_credential(), _datasource_credential("cred-2", is_default=False)], + } + + +def _success_response() -> dict[str, str]: + return {"result": "success"} + class TestDatasourcePluginOAuthAuthorizationUrl: def test_get_success(self, app: Flask): @@ -30,28 +101,50 @@ class TestDatasourcePluginOAuthAuthorizationUrl: method = inspect.unwrap(api.get) user = MagicMock(id="user-1") + oauth_client = {"client_id": "abc", "client_secret": "shh", "scopes": ["read", "write"]} + auth_url_payload = { + "authorization_url": "https://auth.example.com/oauth?client_id=abc&state=xyz", + } with ( app.test_request_context("/?credential_id=cred-1"), patch.object( DatasourceProviderService, "get_oauth_client", - return_value={"client_id": "abc"}, - ), + return_value=oauth_client, + ) as get_oauth_client, patch.object( OAuthProxyService, "create_proxy_context", return_value="ctx-1", - ), + ) as create_proxy_context, patch.object( OAuthHandler, "get_authorization_url", - return_value={"url": "http://auth"}, - ), + return_value=auth_url_payload, + ) as get_authorization_url, ): - response = method(api, "tenant-1", user, "notion") + response = method(api, "tenant-1", user, _PROVIDER_ID) assert response.status_code == 200 + assert response.get_json() == auth_url_payload + assert "context_id=ctx-1" in response.headers.get("Set-Cookie") + provider_id = get_oauth_client.call_args.kwargs["datasource_provider_id"] + assert str(provider_id) == _PROVIDER_ID + get_oauth_client.assert_called_once() + create_proxy_context.assert_called_once_with( + user_id="user-1", + tenant_id="tenant-1", + plugin_id="langgenius/notion_datasource", + provider="notion", + credential_id="cred-1", + ) + get_authorization_url.assert_called_once() + assert get_authorization_url.call_args.kwargs["tenant_id"] == "tenant-1" + assert get_authorization_url.call_args.kwargs["user_id"] == "user-1" + assert get_authorization_url.call_args.kwargs["plugin_id"] == "langgenius/notion_datasource" + assert get_authorization_url.call_args.kwargs["provider"] == "notion" + assert get_authorization_url.call_args.kwargs["system_credentials"] == oauth_client def test_get_no_oauth_config(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() @@ -90,10 +183,10 @@ class TestDatasourcePluginOAuthAuthorizationUrl: patch.object( OAuthHandler, "get_authorization_url", - return_value={"url": "http://auth"}, + return_value={"authorization_url": "http://auth"}, ), ): - response = method(api, "tenant-1", user, "notion") + response = method(api, "tenant-1", user, _PROVIDER_ID) assert response.status_code == 200 assert "context_id" in response.headers.get("Set-Cookie") @@ -106,8 +199,9 @@ class TestDatasourceOAuthCallback: oauth_response = MagicMock() oauth_response.credentials = {"token": "abc"} - oauth_response.expires_at = None - oauth_response.metadata = {"name": "test"} + expires_at = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) + oauth_response.expires_at = expires_at + oauth_response.metadata = {"name": "Workspace Bot", "avatar_url": "https://avatar.example.com/bot.png"} context = { "user_id": "user-1", @@ -125,7 +219,7 @@ class TestDatasourceOAuthCallback: patch.object( DatasourceProviderService, "get_oauth_client", - return_value={"client_id": "abc"}, + return_value={"client_id": "abc", "client_secret": "secret"}, ), patch.object( OAuthHandler, @@ -136,11 +230,22 @@ class TestDatasourceOAuthCallback: DatasourceProviderService, "add_datasource_oauth_provider", return_value=None, - ), + ) as add_oauth_provider, ): - response = method(api, "notion") + response = method(api, _PROVIDER_ID) assert response.status_code == 302 + assert "/oauth-callback" in response.location + add_oauth_provider.assert_called_once() + assert add_oauth_provider.call_args.kwargs == { + "tenant_id": "tenant-1", + "provider_id": add_oauth_provider.call_args.kwargs["provider_id"], + "avatar_url": "https://avatar.example.com/bot.png", + "name": "Workspace Bot", + "expire_at": expires_at, + "credentials": {"token": "abc"}, + } + assert str(add_oauth_provider.call_args.kwargs["provider_id"]) == _PROVIDER_ID def test_callback_missing_context(self, app: Flask): api = DatasourceOAuthCallback() @@ -223,12 +328,16 @@ class TestDatasourceOAuthCallback: DatasourceProviderService, "reauthorize_datasource_oauth_provider", return_value=None, - ), + ) as reauthorize_provider, ): - response = method(api, "notion") + response = method(api, _PROVIDER_ID) assert response.status_code == 302 assert "/oauth-callback" in response.location + reauthorize_provider.assert_called_once() + assert str(reauthorize_provider.call_args.kwargs["provider_id"]) == _PROVIDER_ID + assert reauthorize_provider.call_args.kwargs["credential_id"] == "cred-1" + assert reauthorize_provider.call_args.kwargs["credentials"] == {"token": "abc"} def test_callback_context_id_from_cookie(self, app: Flask): api = DatasourceOAuthCallback() @@ -278,7 +387,14 @@ class TestDatasourceAuth: api = DatasourceAuth() method = inspect.unwrap(api.post) - payload = {"credentials": {"key": "val"}} + payload = { + "name": "Engineering Notion", + "credentials": { + "api_key": "secret-token", + "workspace": "engineering", + "database_id": "db-123", + }, + } with ( app.test_request_context("/", json=payload), @@ -287,11 +403,17 @@ class TestDatasourceAuth: DatasourceProviderService, "add_datasource_api_key_provider", return_value=None, - ), + ) as add_api_key_provider, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + add_api_key_provider.assert_called_once() + assert add_api_key_provider.call_args.kwargs["tenant_id"] == "tenant-1" + assert str(add_api_key_provider.call_args.kwargs["provider_id"]) == _PROVIDER_ID + assert add_api_key_provider.call_args.kwargs["credentials"] == payload["credentials"] + assert add_api_key_provider.call_args.kwargs["name"] == "Engineering Notion" def test_post_invalid_credentials(self, app: Flask): api = DatasourceAuth() @@ -321,19 +443,19 @@ class TestDatasourceAuth: patch.object( DatasourceProviderService, "list_datasource_credentials", - return_value=[{"id": "1"}], + return_value=[_datasource_credential()], ), ): - response, status = method(api, "tenant-1", user, "notion") + response, status = method(api, "tenant-1", user, _PROVIDER_ID) assert status == 200 - assert response["result"] + assert response == {"result": [_datasource_credential()]} def test_post_missing_credentials(self, app: Flask): api = DatasourceAuth() method = inspect.unwrap(api.post) - payload = {} + payload: dict[str, object] = {} with ( app.test_request_context("/", json=payload), @@ -375,17 +497,24 @@ class TestDatasourceAuthDeleteApi: DatasourceProviderService, "remove_datasource_credentials", return_value=None, - ), + ) as remove_datasource_credentials, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + remove_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + auth_id="cred-1", + provider="notion", + plugin_id="langgenius/notion_datasource", + ) def test_delete_missing_credential_id(self, app: Flask): api = DatasourceAuthDeleteApi() method = inspect.unwrap(api.post) - payload = {} + payload: dict[str, object] = {} with ( app.test_request_context("/", json=payload), @@ -400,7 +529,11 @@ class TestDatasourceAuthUpdateApi: api = DatasourceAuthUpdateApi() method = inspect.unwrap(api.post) - payload = {"credential_id": "id", "credentials": {"k": "v"}} + payload = { + "credential_id": "cred-1", + "name": "Updated Notion", + "credentials": {"api_key": "new-secret", "database_id": "db-456"}, + } with ( app.test_request_context("/", json=payload), @@ -409,11 +542,20 @@ class TestDatasourceAuthUpdateApi: DatasourceProviderService, "update_datasource_credentials", return_value=None, - ), + ) as update_datasource_credentials, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 201 + update_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + auth_id="cred-1", + provider="notion", + plugin_id="langgenius/notion_datasource", + credentials=payload["credentials"], + name="Updated Notion", + ) def test_update_with_credentials_none(self, app: Flask): api = DatasourceAuthUpdateApi() @@ -432,7 +574,9 @@ class TestDatasourceAuthUpdateApi: ): response, status = method(api, "tenant-1", "notion") + assert response == _success_response() update_mock.assert_called_once() + assert update_mock.call_args.kwargs["credentials"] == {} assert status == 201 def test_update_name_only(self, app: Flask): @@ -450,8 +594,9 @@ class TestDatasourceAuthUpdateApi: return_value=None, ), ): - _, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", "notion") + assert response == _success_response() assert status == 201 def test_update_with_empty_credentials_dict(self, app: Flask): @@ -469,8 +614,9 @@ class TestDatasourceAuthUpdateApi: return_value=None, ) as update_mock, ): - _, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", "notion") + assert response == _success_response() update_mock.assert_called_once() assert status == 201 @@ -485,12 +631,14 @@ class TestDatasourceAuthListApi: patch.object( DatasourceProviderService, "get_all_datasource_credentials", - return_value=[{"id": "1"}], + return_value=[_datasource_auth()], ), ): response, status = method(api, "tenant-1") assert status == 200 + assert response == {"result": [_datasource_auth()]} + assert response == {"result": [_datasource_auth()]} def test_auth_list_empty(self, app: Flask): api = DatasourceAuthListApi() @@ -537,7 +685,7 @@ class TestDatasourceHardCodeAuthListApi: patch.object( DatasourceProviderService, "get_hard_code_datasource_credentials", - return_value=[{"id": "1"}], + return_value=[_datasource_auth()], ), ): response, status = method(api, "tenant-1") @@ -550,7 +698,14 @@ class TestDatasourceAuthOauthCustomClient: api = DatasourceAuthOauthCustomClient() method = inspect.unwrap(api.post) - payload = {"client_params": {}, "enable_oauth_custom_client": True} + payload = { + "client_params": { + "client_id": "custom-client", + "client_secret": "custom-secret", + "authorize_url": "https://auth.example.com/authorize", + }, + "enable_oauth_custom_client": True, + } with ( app.test_request_context("/", json=payload), @@ -559,11 +714,17 @@ class TestDatasourceAuthOauthCustomClient: DatasourceProviderService, "setup_oauth_custom_client_params", return_value=None, - ), + ) as setup_custom_client, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + setup_custom_client.assert_called_once() + assert setup_custom_client.call_args.kwargs["tenant_id"] == "tenant-1" + assert str(setup_custom_client.call_args.kwargs["datasource_provider_id"]) == _PROVIDER_ID + assert setup_custom_client.call_args.kwargs["client_params"] == payload["client_params"] + assert setup_custom_client.call_args.kwargs["enabled"] is True def test_delete_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() @@ -575,17 +736,20 @@ class TestDatasourceAuthOauthCustomClient: DatasourceProviderService, "remove_oauth_custom_client_params", return_value=None, - ), + ) as remove_custom_client, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + remove_custom_client.assert_called_once() + assert str(remove_custom_client.call_args.kwargs["datasource_provider_id"]) == _PROVIDER_ID def test_post_empty_payload(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = inspect.unwrap(api.post) - payload = {} + payload: dict[str, object] = {} with ( app.test_request_context("/", json=payload), @@ -596,8 +760,9 @@ class TestDatasourceAuthOauthCustomClient: return_value=None, ), ): - _, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", "notion") + assert response == _success_response() assert status == 200 def test_post_disabled_flag(self, app: Flask): @@ -618,9 +783,12 @@ class TestDatasourceAuthOauthCustomClient: return_value=None, ) as setup_mock, ): - _, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", "notion") + assert response == _success_response() setup_mock.assert_called_once() + assert setup_mock.call_args.kwargs["client_params"] == {"a": 1} + assert setup_mock.call_args.kwargs["enabled"] is False assert status == 200 @@ -638,17 +806,22 @@ class TestDatasourceAuthDefaultApi: DatasourceProviderService, "set_default_datasource_provider", return_value=None, - ), + ) as set_default_datasource_provider, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + set_default_datasource_provider.assert_called_once() + assert set_default_datasource_provider.call_args.kwargs["tenant_id"] == "tenant-1" + assert str(set_default_datasource_provider.call_args.kwargs["datasource_provider_id"]) == _PROVIDER_ID + assert set_default_datasource_provider.call_args.kwargs["credential_id"] == "cred-1" def test_default_missing_id(self, app: Flask): api = DatasourceAuthDefaultApi() method = inspect.unwrap(api.post) - payload = {} + payload: dict[str, object] = {} with ( app.test_request_context("/", json=payload), @@ -663,7 +836,7 @@ class TestDatasourceUpdateProviderNameApi: api = DatasourceUpdateProviderNameApi() method = inspect.unwrap(api.post) - payload = {"credential_id": "id", "name": "New Name"} + payload = {"credential_id": "cred-1", "name": "New Name"} with ( app.test_request_context("/", json=payload), @@ -672,11 +845,17 @@ class TestDatasourceUpdateProviderNameApi: DatasourceProviderService, "update_datasource_provider_name", return_value=None, - ), + ) as update_datasource_provider_name, ): - response, status = method(api, "tenant-1", "notion") + response, status = method(api, "tenant-1", _PROVIDER_ID) + assert response == _success_response() assert status == 200 + update_datasource_provider_name.assert_called_once() + assert update_datasource_provider_name.call_args.kwargs["tenant_id"] == "tenant-1" + assert str(update_datasource_provider_name.call_args.kwargs["datasource_provider_id"]) == _PROVIDER_ID + assert update_datasource_provider_name.call_args.kwargs["name"] == "New Name" + assert update_datasource_provider_name.call_args.kwargs["credential_id"] == "cred-1" def test_update_name_too_long(self, app: Flask): api = DatasourceUpdateProviderNameApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 5cc5af9592b..19dc90ed8a4 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -158,3 +158,65 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app: Flask, monke assert response["id"] == "workflow-1" assert response["marked_name"] == "Updated release" assert response["hash"] == "hash-1" + + +def test_default_rag_pipeline_block_configs_serializes_root_response(monkeypatch: pytest.MonkeyPatch) -> None: + block_configs = [{"type": "start", "config": {"title": "Start"}}] + monkeypatch.setattr( + module, + "RagPipelineService", + lambda: SimpleNamespace(get_default_block_configs=lambda: block_configs), + ) + + api = module.DefaultRagPipelineBlockConfigsApi() + handler = unwrap_all(api.get) + + response = handler(api, _pipeline()) + + assert response == block_configs + + +def test_draft_rag_pipeline_second_step_parameters_serializes_variables(app, monkeypatch: pytest.MonkeyPatch) -> None: + variables = [ + { + "belong_to_node_id": "shared", + "type": "number", + "label": "Chunk size", + "variable": "chunk_size", + "default_value": 1024, + "required": True, + } + ] + monkeypatch.setattr( + module, + "RagPipelineService", + lambda: SimpleNamespace(get_second_step_parameters=lambda **_kwargs: variables), + ) + + api = module.DraftRagPipelineSecondStepApi() + handler = unwrap_all(api.get) + + with app.test_request_context("/?node_id=node-1"): + response = handler(api, _pipeline()) + + assert response["variables"] == variables + + +def test_rag_pipeline_recommended_plugins_serializes_known_envelope(app, monkeypatch: pytest.MonkeyPatch) -> None: + recommended_plugins = { + "installed_recommended_plugins": [{"name": "Dify Extractor", "meta": {"version": "1.0.0"}}], + "uninstalled_recommended_plugins": [{"plugin_id": "langgenius/notion_datasource"}], + } + monkeypatch.setattr( + module, + "RagPipelineService", + lambda: SimpleNamespace(get_recommended_plugins=lambda *_args: recommended_plugins), + ) + + api = module.RagPipelineRecommendedPluginApi() + handler = unwrap_all(api.get) + + with app.test_request_context("/?type=tool"): + response = handler(api, "tenant-1", _account()) + + assert response == recommended_plugins diff --git a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py index 362af883ed2..43cc2450db5 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py @@ -325,10 +325,12 @@ class TestPipelineRunApiEntity: def test_entity_missing_required_field(self): """Test entity raises on missing required field.""" with pytest.raises(ValueError): - PipelineRunApiEntity( - inputs={}, - datasource_type="online_document", - # missing datasource_info_list, start_node_id, etc. + PipelineRunApiEntity.model_validate( + { + "inputs": {}, + "datasource_type": "online_document", + # missing datasource_info_list, start_node_id, etc. + } ) @@ -382,8 +384,19 @@ class TestDatasourcePluginsApiGet: mock_dataset = Mock() mock_db.session.scalar.return_value = mock_dataset + datasource_plugins = [ + { + "node_id": "node-datasource-1", + "plugin_id": "plugin-a", + "provider_name": "provider-a", + "datasource_type": "online_document", + "title": "Online Docs", + "user_input_variables": [{"variable": "url", "label": "URL", "type": "text-input", "required": True}], + "credentials": [{"id": "cred-1", "name": "Default credential", "type": "oauth2", "is_default": True}], + } + ] mock_svc_instance = Mock() - mock_svc_instance.get_datasource_plugins.return_value = [{"name": "plugin_a"}] + mock_svc_instance.get_datasource_plugins.return_value = datasource_plugins mock_svc_cls.return_value = mock_svc_instance with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=true"): @@ -391,11 +404,33 @@ class TestDatasourcePluginsApiGet: response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id) assert status == 200 - assert response == [{"name": "plugin_a"}] + assert response == datasource_plugins mock_svc_instance.get_datasource_plugins.assert_called_once_with( tenant_id=tenant_id, dataset_id=dataset_id, is_published=True ) + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") + def test_get_plugins_parses_false_is_published_query(self, mock_svc_cls, mock_db, app: Flask): + """Test false query string is parsed as boolean False.""" + tenant_id = str(uuid.uuid4()) + dataset_id = str(uuid.uuid4()) + + mock_db.session.scalar.return_value = Mock() + mock_svc_instance = Mock() + mock_svc_instance.get_datasource_plugins.return_value = [] + mock_svc_cls.return_value = mock_svc_instance + + with app.test_request_context("/datasets/test/pipeline/datasource-plugins?is_published=false"): + api = DatasourcePluginsApi() + response, status = api.get(tenant_id=tenant_id, dataset_id=dataset_id) + + assert status == 200 + assert response == [] + mock_svc_instance.get_datasource_plugins.assert_called_once_with( + tenant_id=tenant_id, dataset_id=dataset_id, is_published=False + ) + @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") def test_get_plugins_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" diff --git a/packages/contracts/generated/api/console/auth/types.gen.ts b/packages/contracts/generated/api/console/auth/types.gen.ts index 99a06128e14..d7c01347da3 100644 --- a/packages/contracts/generated/api/console/auth/types.gen.ts +++ b/packages/contracts/generated/api/console/auth/types.gen.ts @@ -4,8 +4,12 @@ export type ClientOptions = { baseUrl: `${string}://${string}/console/api` | (string & {}) } -export type DatasourceCredentialsResponse = { - result: unknown +export type DatasourceProviderAuthListResponse = { + result: Array +} + +export type DatasourceCredentialListResponse = { + result: Array } export type DatasourceCredentialPayload = { @@ -47,6 +51,83 @@ export type DatasourceUpdateNamePayload = { name: string } +export type DatasourceProviderAuthResponse = { + author: string + credential_schema: Array + credentials_list: Array + description: I18nObject + icon: string + label: I18nObject + name: string + oauth_schema: DatasourceOAuthSchemaResponse | null + plugin_id: string + plugin_unique_identifier: string + provider: string +} + +export type DatasourceCredentialResponse = { + avatar_url: string | null + credential: { + [key: string]: unknown + } + id: string + is_default: boolean + name: string + type: string +} + +export type ProviderConfig = { + default?: number | string | number | boolean | null + help?: I18nObject | null + label?: I18nObject | null + multiple?: boolean + name: string + options?: Array