diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index b040e70b92..da6db303cd 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -428,7 +428,6 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         )
 
 
-
 class RagPipelineDraftDatasourceNodeRunApi(Resource):
     @setup_required
     @login_required
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index 96f8f01032..8769bcea0d 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -589,7 +589,7 @@ class PipelineGenerator(BaseAppGenerator):
             if datasource_type == "local_file":
                 name = datasource_info["name"]
             elif datasource_type == "online_document":
-                name = datasource_info['page']["page_name"]
+                name = datasource_info["page"]["page_name"]
             elif datasource_type == "website_crawl":
                 name = datasource_info["title"]
             else:
diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py
index a2351a19da..53104b0061 100644
--- a/api/core/plugin/impl/datasource.py
+++ b/api/core/plugin/impl/datasource.py
@@ -7,7 +7,7 @@ from core.datasource.entities.datasource_entities import (
     OnlineDocumentPagesMessage,
     WebsiteCrawlMessage,
 )
-from core.plugin.entities.plugin import GenericProviderID, ToolProviderID, DatasourceProviderID
+from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
     PluginDatasourceProviderEntity,
@@ -41,8 +41,8 @@ class PluginDatasourceManager(BasePluginClient):
         )
         local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
-        # for provider in response:
-        #     ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
+        for provider in response:
+            ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
 
         all_response = [local_file_datasource_provider] + response
 
         for provider in all_response:
diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py
index 4acb558531..a36e32fc9c 100644
--- a/api/core/rag/entities/event.py
+++ b/api/core/rag/entities/event.py
@@ -9,6 +9,7 @@ class DatasourceStreamEvent(Enum):
     """
     Datasource Stream event
     """
+
     PROCESSING = "datasource_processing"
     COMPLETED = "datasource_completed"
     ERROR = "datasource_error"
@@ -17,19 +18,21 @@
 class BaseDatasourceEvent(BaseModel):
     pass
 
+
 class DatasourceErrorEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.ERROR.value
     error: str = Field(..., description="error message")
 
+
 class DatasourceCompletedEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.COMPLETED.value
-    data: Mapping[str,Any] | list = Field(..., description="result")
+    data: Mapping[str, Any] | list = Field(..., description="result")
     total: Optional[int] = Field(default=0, description="total")
     completed: Optional[int] = Field(default=0, description="completed")
     time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
 
+
 class DatasourceProcessingEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.PROCESSING.value
     total: Optional[int] = Field(..., description="total")
     completed: Optional[int] = Field(..., description="completed")
-
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 3f82bda2c6..e382ff6b54 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -68,12 +68,15 @@ class QAChunk(BaseModel):
     question: str
    answer: str
 
+
 class QAStructureChunk(BaseModel):
     """
     QAStructureChunk.
     """
+
     qa_chunks: list[QAChunk]
 
+
 class BaseDocumentTransformer(ABC):
     """Abstract base class for document transformation systems.
 
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 063216dd49..9a4939502e 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -273,5 +273,3 @@ class AgentLogEvent(BaseAgentEvent):
 
 
 InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
-
-
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 1f9337665b..333d559bf5 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -69,6 +69,7 @@ from services.rag_pipeline.pipeline_template.pipeline_template_factory import Pi
 
 logger = logging.getLogger(__name__)
 
+
 class RagPipelineService:
     @classmethod
     def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
@@ -431,8 +432,13 @@ class RagPipelineService:
         return workflow_node_execution
 
     def run_datasource_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
-        is_published: bool
+        self,
+        pipeline: Pipeline,
+        node_id: str,
+        user_inputs: dict,
+        account: Account,
+        datasource_type: str,
+        is_published: bool,
     ) -> Generator[BaseDatasourceEvent, None, None]:
         """
         Run published workflow datasource
@@ -497,23 +503,21 @@
                     for message in online_document_result:
                         end_time = time.time()
                         online_document_event = DatasourceCompletedEvent(
-                            data=message.result,
-                            time_consuming=round(end_time - start_time, 2)
+                            data=message.result, time_consuming=round(end_time - start_time, 2)
                         )
                         yield online_document_event.model_dump()
                 except Exception as e:
                     logger.exception("Error during online document.")
-                    yield DatasourceErrorEvent(
-                        error=str(e)
-                    ).model_dump()
+                    yield DatasourceErrorEvent(error=str(e)).model_dump()
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
                 website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
                     datasource_runtime.get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters=user_inputs,
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                ))
+                        user_id=account.id,
+                        datasource_parameters=user_inputs,
+                        provider_type=datasource_runtime.datasource_provider_type(),
+                    )
+                )
                 start_time = time.time()
                 try:
                     for message in website_crawl_result:
@@ -523,7 +527,7 @@
                                 data=message.result.web_info_list,
                                 total=message.result.total,
                                 completed=message.result.completed,
-                                time_consuming=round(end_time - start_time, 2)
+                                time_consuming=round(end_time - start_time, 2),
                             )
                         else:
                             crawl_event = DatasourceProcessingEvent(
@@ -533,16 +537,12 @@
                        yield crawl_event.model_dump()
                 except Exception as e:
                     logger.exception("Error during website crawl.")
-                    yield DatasourceErrorEvent(
-                        error=str(e)
-                    ).model_dump()
+                    yield DatasourceErrorEvent(error=str(e)).model_dump()
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
         except Exception as e:
             logger.exception("Error in run_datasource_workflow_node.")
-            yield DatasourceErrorEvent(
-                error=str(e)
-            ).model_dump()
+            yield DatasourceErrorEvent(error=str(e)).model_dump()
 
     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index 367121125b..8a73c73a1b 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -5,6 +5,7 @@ from typing import Optional, Union, cast
 from yarl import URL
 
 from configs import dify_config
+from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -56,7 +57,7 @@ class ToolTransformService:
         return ""
 
     @staticmethod
-    def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
+    def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
         """
         repack provider
 
@@ -77,6 +78,17 @@
             provider.icon = ToolTransformService.get_tool_provider_icon_url(
                 provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
             )
+        elif isinstance(provider, PluginDatasourceProviderEntity):
+            if provider.plugin_id:
+                if isinstance(provider.declaration.identity.icon, str):
+                    provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
+                        tenant_id=tenant_id, filename=provider.declaration.identity.icon
+                    )
+            else:
+                provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
+                    provider_type=provider.type.value, provider_name=provider.name,
+                    icon=provider.declaration.identity.icon
+                )
 
     @classmethod
     def builtin_provider_to_user_provider(