diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index cc07084dea..8c5f91cb7f 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -50,8 +50,8 @@ class PipelineTemplateDetailApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def get(self, pipeline_id: str): - pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) + def get(self, template_id: str): + pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id) return pipeline_template, 200 @@ -120,7 +120,7 @@ api.add_resource( ) api.add_resource( PipelineTemplateDetailApi, - "/rag/pipeline/templates/", + "/rag/pipeline/templates/", ) api.add_resource( CustomizedPipelineTemplateApi, diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index e9f73d3c18..22ec5c3d23 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -4,6 +4,8 @@ from typing import Any, Optional from pydantic import BaseModel, Field, ValidationInfo, field_validator +from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.oauth import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -13,7 +15,7 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderEntity +from core.tools.entities.tool_entities import ToolLabelEnum, ToolProviderEntity class DatasourceProviderType(enum.StrEnum): @@ -118,29 +120,36 @@ class DatasourceIdentity(BaseModel): icon: Optional[str] = None -class DatasourceDescription(BaseModel): - human: I18nObject = Field(..., description="The description presented to the user") - llm: str = Field(..., description="The description presented to the LLM") - - class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) - description: Optional[DatasourceDescription] = None + description: I18nObject = Field(..., description="The label of the datasource") output_schema: Optional[dict] = None - has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: return v or [] +class DatasourceProviderIdentity(BaseModel): + author: str = Field(..., description="The author of the tool") + name: str = Field(..., description="The name of the tool") + description: I18nObject = Field(..., description="The description of the tool") + icon: str = Field(..., description="The icon of the tool") + label: I18nObject = Field(..., description="The label of the tool") + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) -class DatasourceProviderEntity(ToolProviderEntity): + +class DatasourceProviderEntity(BaseModel): """ Datasource provider entity """ - + identity: DatasourceProviderIdentity + credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None provider_type: DatasourceProviderType @@ -202,7 +211,6 @@ class GetOnlineDocumentPagesRequest(BaseModel): Get online document pages request """ - tenant_id: str = Field(..., description="The tenant id") class OnlineDocumentPageIcon(BaseModel): @@ -276,8 +284,6 @@ class GetWebsiteCrawlRequest(BaseModel): """ Get website crawl request """ - - url: str = Field(..., description="The url of the website") crawl_parameters: dict = Field(..., description="The crawl parameters") @@ -297,4 +303,4 @@ class GetWebsiteCrawlResponse(BaseModel): Get website crawl response """ - result: WebSiteInfo + result: list[WebSiteInfo] diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index 07d7a25160..7809ac2a89 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,3 +1,4 @@ +from typing import Any, Mapping from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -34,7 +35,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): def _get_online_document_pages( self, user_id: str, - datasource_parameters: GetOnlineDocumentPagesRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetOnlineDocumentPagesResponse: manager = PluginDatasourceManager() diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 5f92551198..e657fceb9c 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,3 +1,4 @@ +from typing import Any, Mapping from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -32,7 +33,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): def _get_website_crawl( self, user_id: str, - datasource_parameters: GetWebsiteCrawlRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetWebsiteCrawlResponse: manager = PluginDatasourceManager() diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 3b0defbb08..90086173fa 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -52,7 +52,6 @@ class PluginDatasourceProviderEntity(BaseModel): provider: str plugin_unique_identifier: str plugin_id: str - author: str declaration: DatasourceProviderEntityWithPlugin diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 80d868c1af..430a9a6c01 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -1,12 +1,10 @@ -from typing import Any +from typing import Any, Mapping from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.datasource.entities.datasource_entities import ( GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentResponse, - GetOnlineDocumentPagesRequest, GetOnlineDocumentPagesResponse, - GetWebsiteCrawlRequest, GetWebsiteCrawlResponse, ) from core.plugin.entities.plugin import GenericProviderID, ToolProviderID @@ -86,7 +84,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: GetWebsiteCrawlRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetWebsiteCrawlResponse: """ @@ -125,7 +123,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider: str, datasource_name: str, credentials: dict[str, Any], - datasource_parameters: GetOnlineDocumentPagesRequest, + datasource_parameters: Mapping[str, Any], provider_type: str, ) -> GetOnlineDocumentPagesResponse: """ diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index bf582b9d27..3bee0538ab 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -67,15 +67,15 @@ class RagPipelineService: return result.get("pipeline_templates") @classmethod - def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]: """ Get pipeline template detail. - :param pipeline_id: pipeline id + :param template_id: template id :return: """ mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE - retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) - result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id) + retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() + result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) return result @classmethod @@ -427,7 +427,7 @@ class RagPipelineService: online_document_result: GetOnlineDocumentPagesResponse = ( datasource_runtime._get_online_document_pages( user_id=account.id, - datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id), + datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) ) @@ -440,11 +440,11 @@ class RagPipelineService: datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( user_id=account.id, - datasource_parameters=GetWebsiteCrawlRequest(**user_inputs), + datasource_parameters=user_inputs, provider_type=datasource_runtime.datasource_provider_type(), ) return { - "result": website_crawl_result.result.model_dump(), + "result": [result.model_dump() for result in website_crawl_result.result], "provider_type": datasource_node_data.get("provider_type"), } else: