diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index f0f6cd66e6..0442a121c0 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,7 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" - - "feat/r2" + - "deploy/rag-dev" tags: - "*" diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 47ca03c2eb..0d99c6fa58 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -4,7 +4,7 @@ on: workflow_run: workflows: ["Build and Push API & Web"] branches: - - "deploy/dev" + - "deploy/rag-dev" types: - completed @@ -12,12 +12,13 @@ jobs: deploy: runs-on: ubuntu-latest if: | - github.event.workflow_run.conclusion == 'success' + github.event.workflow_run.conclusion == 'success' && + github.event.workflow_run.head_branch == 'deploy/rag-dev' steps: - name: Deploy to server uses: appleboy/ssh-action@v0.1.8 with: - host: ${{ secrets.SSH_HOST }} + host: ${{ secrets.RAG_SSH_HOST }} username: ${{ secrets.SSH_USER }} key: ${{ secrets.SSH_PRIVATE_KEY }} script: | diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 28ee7395d6..312a870472 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -86,6 +86,7 @@ from .datasets import ( ) from .datasets.rag_pipeline import ( datasource_auth, + datasource_content_preview, rag_pipeline, rag_pipeline_datasets, rag_pipeline_import, diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py new file mode 100644 index 0000000000..32b5f68364 --- /dev/null +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -0,0 +1,54 @@ +from flask_restful import ( # type: ignore + Resource, # type: ignore + reqparse, +) +from werkzeug.exceptions import Forbidden + +from controllers.console import api +from controllers.console.datasets.wraps import get_rag_pipeline +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_user, login_required +from models import Account +from models.dataset import Pipeline +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class DataSourceContentPreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, node_id: str): + """ + Run datasource content preview + """ + if not isinstance(current_user, Account): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("datasource_type", type=str, required=True, location="json") + args = parser.parse_args() + + inputs = args.get("inputs") + if inputs is None: + raise ValueError("missing inputs") + datasource_type = args.get("datasource_type") + if datasource_type is None: + raise ValueError("missing datasource_type") + + rag_pipeline_service = RagPipelineService() + preview_content = rag_pipeline_service.run_datasource_node_preview( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=True, + ) + return preview_content, 200 + +api.add_resource( + DataSourceContentPreviewApi, + "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview" +) diff --git
a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 3ef0c42d0f..8bae9dc466 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -414,17 +414,19 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - result = rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=True, + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, + ) + ) ) - return result - class RagPipelineDraftDatasourceNodeRunApi(Resource): @setup_required @@ -455,21 +457,18 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): raise ValueError("missing datasource_type") rag_pipeline_service = RagPipelineService() - try: - return helper.compact_generate_response( - PipelineGenerator.convert_to_event_stream( - rag_pipeline_service.run_datasource_workflow_node( - pipeline=pipeline, - node_id=node_id, - user_inputs=inputs, - account=current_user, - datasource_type=datasource_type, - is_published=False, - ) + return helper.compact_generate_response( + PipelineGenerator.convert_to_event_stream( + rag_pipeline_service.run_datasource_workflow_node( + pipeline=pipeline, + node_id=node_id, + user_inputs=inputs, + account=current_user, + datasource_type=datasource_type, + is_published=False, ) ) - except Exception as e: - print(e) + ) class RagPipelinePublishedNodeRunApi(Resource): diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 7c0bbc46d9..13acc4ef38 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -2,14 +2,14 @@ import contextvars import datetime import json import logging -import random +import secrets import threading import time import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Optional, Union, overload -from flask import Flask, copy_current_request_context, current_app, has_request_context +from flask import Flask, current_app from pydantic import ValidationError from sqlalchemy.orm import sessionmaker @@ -110,7 +110,7 @@ class PipelineGenerator(BaseAppGenerator): start_node_id: str = args["start_node_id"] datasource_type: str = args["datasource_type"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) documents = [] if invoke_from == InvokeFrom.PUBLISHED: for datasource_info in datasource_info_list: @@ -589,7 +589,7 @@ class PipelineGenerator(BaseAppGenerator): if datasource_type == "local_file": name = datasource_info["name"] elif datasource_type == "online_document": - name = datasource_info["page_title"] + name = datasource_info["page"]["page_name"] elif datasource_type == "website_crawl": name = datasource_info["title"] else: diff --git a/api/core/datasource/entities/datasource_entities.py 
b/api/core/datasource/entities/datasource_entities.py index 41be2dcc3d..2c3de1e5d7 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -214,10 +214,11 @@ class OnlineDocumentPage(BaseModel): """ page_id: str = Field(..., description="The page id") - page_title: str = Field(..., description="The page title") + page_name: str = Field(..., description="The page title") page_icon: Optional[dict] = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") + parent_id: Optional[str] = Field(None, description="The parent page id") class OnlineDocumentInfo(BaseModel): diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 6dc5ebca6b..a69e88baf4 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -141,7 +141,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", OnlineDocumentPagesMessage, @@ -159,7 +159,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def get_online_document_page_content( self, @@ -177,7 +176,7 @@ class PluginDatasourceManager(BasePluginClient): datasource_provider_id = GenericProviderID(datasource_provider) - response = self._request_with_plugin_daemon_response_stream( + return self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content", DatasourceMessage, @@ -195,7 +194,6 @@ class PluginDatasourceManager(BasePluginClient): "Content-Type": "application/json", }, ) - yield from response def online_drive_browse_files( self, diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index 4921c94557..a36e32fc9c 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -9,22 +9,30 @@ class DatasourceStreamEvent(Enum): """ Datasource Stream event """ + PROCESSING = "datasource_processing" COMPLETED = "datasource_completed" + ERROR = "datasource_error" class BaseDatasourceEvent(BaseModel): pass + +class DatasourceErrorEvent(BaseDatasourceEvent): + event: str = DatasourceStreamEvent.ERROR.value + error: str = Field(..., description="error message") + + class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value - data: Mapping[str,Any] | list = Field(..., description="result") - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") - time_consuming: Optional[float] = Field(..., description="time consuming") + data: Mapping[str, Any] | list = Field(..., description="result") + total: Optional[int] = Field(default=0, description="total") + completed: Optional[int] = Field(default=0, description="completed") + time_consuming: Optional[float] = Field(default=0.0, description="time consuming") + class DatasourceProcessingEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.PROCESSING.value total: Optional[int] = Field(..., description="total") completed: Optional[int] = Field(..., description="completed") - diff --git a/api/core/rag/models/document.py 
b/api/core/rag/models/document.py index 3f82bda2c6..e382ff6b54 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -68,12 +68,15 @@ class QAChunk(BaseModel): question: str answer: str + class QAStructureChunk(BaseModel): """ QAStructureChunk. """ + qa_chunks: list[QAChunk] + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 89149c91db..e57e9e4d64 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -275,5 +275,3 @@ class AgentLogEvent(BaseAgentEvent): InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent - - diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index da6ba0fba5..6427c526d6 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1,8 +1,9 @@ import json +import logging import re import threading import time -from collections.abc import Callable, Generator, Sequence +from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional, cast from uuid import uuid4 @@ -15,16 +16,20 @@ import contexts from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.datasource.entities.datasource_entities import ( + DatasourceMessage, DatasourceProviderType, + GetOnlineDocumentPageContentRequest, OnlineDocumentPagesMessage, - OnlineDriveBrowseFilesRequest, - OnlineDriveBrowseFilesResponse, WebsiteCrawlMessage, ) from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin -from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin -from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent +from core.rag.entities.event import ( + BaseDatasourceEvent, + DatasourceCompletedEvent, + DatasourceErrorEvent, + DatasourceProcessingEvent, +) from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.variables.variables import Variable from core.workflow.entities.node_entities import NodeRunResult @@ -64,6 +69,8 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( from services.errors.app import WorkflowHashNotEqualError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory +logger = logging.getLogger(__name__) + class RagPipelineService: @classmethod @@ -116,14 +123,6 @@ class RagPipelineService: ) if not customized_template: raise ValueError("Customized pipeline template not found.") - # check template name is exist - template_name = template_info.name - if template_name: - template = db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.name == template_name, - PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, - PipelineCustomizedTemplate.id != template_id).first() - if template: - raise ValueError("Template name is already exists") customized_template.name = template_info.name customized_template.description = template_info.description customized_template.icon = 
template_info.icon_info.model_dump() @@ -434,157 +433,210 @@ class RagPipelineService: return workflow_node_execution - # def run_datasource_workflow_node_status( - # self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, - # datasource_type: str, is_published: bool - # ) -> dict: - # """ - # Run published workflow datasource - # """ - # if is_published: - # # fetch published workflow by app_model - # workflow = self.get_published_workflow(pipeline=pipeline) - # else: - # workflow = self.get_draft_workflow(pipeline=pipeline) - # if not workflow: - # raise ValueError("Workflow not initialized") - # - # # run draft workflow node - # datasource_node_data = None - # start_at = time.perf_counter() - # datasource_nodes = workflow.graph_dict.get("nodes", []) - # for datasource_node in datasource_nodes: - # if datasource_node.get("id") == node_id: - # datasource_node_data = datasource_node.get("data", {}) - # break - # if not datasource_node_data: - # raise ValueError("Datasource node data not found") - # - # from core.datasource.datasource_manager import DatasourceManager - # - # datasource_runtime = DatasourceManager.get_datasource_runtime( - # provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - # datasource_name=datasource_node_data.get("datasource_name"), - # tenant_id=pipeline.tenant_id, - # datasource_type=DatasourceProviderType(datasource_type), - # ) - # datasource_provider_service = DatasourceProviderService() - # credentials = datasource_provider_service.get_real_datasource_credentials( - # tenant_id=pipeline.tenant_id, - # provider=datasource_node_data.get('provider_name'), - # plugin_id=datasource_node_data.get('plugin_id'), - # ) - # if credentials: - # datasource_runtime.runtime.credentials = credentials[0].get("credentials") - # match datasource_type: - # - # case DatasourceProviderType.WEBSITE_CRAWL: - # datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - # website_crawl_results: list[WebsiteCrawlMessage] = [] - # for website_message in datasource_runtime.get_website_crawl( - # user_id=account.id, - # datasource_parameters={"job_id": job_id}, - # provider_type=datasource_runtime.datasource_provider_type(), - # ): - # website_crawl_results.append(website_message) - # return { - # "result": [result for result in website_crawl_results.result], - # "status": website_crawl_results.result.status, - # "provider_type": datasource_node_data.get("provider_type"), - # } - # case _: - # raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") - def run_datasource_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, - is_published: bool + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, ) -> Generator[BaseDatasourceEvent, None, None]: """ Run published workflow datasource """ - if is_published: - # fetch published workflow by app_model - workflow = self.get_published_workflow(pipeline=pipeline) - else: - workflow = self.get_draft_workflow(pipeline=pipeline) - if not workflow: - raise ValueError("Workflow not initialized") + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") - # run draft workflow node - 
datasource_node_data = None - start_at = time.perf_counter() - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if not datasource_node_data: - raise ValueError("Datasource node data not found") + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - for key, value in datasource_parameters.items(): - if not user_inputs.get(key): - user_inputs[key] = value["value"] + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] - from core.datasource.datasource_manager import DatasourceManager + from core.datasource.datasource_manager import DatasourceManager - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", - datasource_name=datasource_node_data.get("datasource_name"), - tenant_id=pipeline.tenant_id, - datasource_type=DatasourceProviderType(datasource_type), - ) - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_real_datasource_credentials( - tenant_id=pipeline.tenant_id, - provider=datasource_node_data.get("provider_name"), - plugin_id=datasource_node_data.get("plugin_id"), - ) - if credentials: - datasource_runtime.runtime.credentials = credentials[0].get("credentials") - match datasource_type: - case DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( - datasource_runtime.get_online_document_pages( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - ) - start_time = time.time() - for message in online_document_result: - end_time = time.time() - online_document_event = DatasourceCompletedEvent( - data=message.result, - time_consuming=round(end_time - start_time, 2), - ) - yield online_document_event.model_dump() - case DatasourceProviderType.WEBSITE_CRAWL: - datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) - website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl( - user_id=account.id, - datasource_parameters=user_inputs, - provider_type=datasource_runtime.datasource_provider_type(), - ) - start_time = time.time() - for message in website_crawl_result: - end_time = time.time() - if message.result.status == "completed": - crawl_event = DatasourceCompletedEvent( - data=message.result.web_info_list, - total=message.result.total, - completed=message.result.completed, - time_consuming=round(end_time - start_time, 2) + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + 
tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( + datasource_runtime.get_online_document_pages( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), ) - else: - crawl_event = DatasourceProcessingEvent( - total=message.result.total, - completed=message.result.completed, + ) + start_time = time.time() + start_event = DatasourceProcessingEvent( + total=0, + completed=0, + ) + yield start_event.model_dump() + try: + for message in online_document_result: + end_time = time.time() + online_document_event = DatasourceCompletedEvent( + data=message.result, time_consuming=round(end_time - start_time, 2) + ) + yield online_document_event.model_dump() + except Exception as e: + logger.exception("Error during online document.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + case DatasourceProviderType.WEBSITE_CRAWL: + datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) + website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( + datasource_runtime.get_website_crawl( + user_id=account.id, + datasource_parameters=user_inputs, + provider_type=datasource_runtime.datasource_provider_type(), ) - yield crawl_event.model_dump() - case _: - raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + ) + start_time = time.time() + try: + for message in website_crawl_result: + end_time = time.time() + if message.result.status == "completed": + crawl_event = DatasourceCompletedEvent( + data=message.result.web_info_list, + total=message.result.total, + completed=message.result.completed, + time_consuming=round(end_time - start_time, 2), + ) + else: + crawl_event = DatasourceProcessingEvent( + total=message.result.total, + completed=message.result.completed, + ) + yield crawl_event.model_dump() + except Exception as e: + logger.exception("Error during website crawl.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_workflow_node.") + yield DatasourceErrorEvent(error=str(e)).model_dump() + + def run_datasource_node_preview( + self, + pipeline: Pipeline, + node_id: str, + user_inputs: dict, + account: Account, + datasource_type: str, + is_published: bool, + ) -> Mapping[str, Any]: + """ + Run published workflow datasource + """ + try: + if is_published: + # fetch published workflow by app_model + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # run draft workflow node + datasource_node_data = None + datasource_nodes = workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: 
+ if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + for key, value in datasource_parameters.items(): + if not user_inputs.get(key): + user_inputs[key] = value["value"] + + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}", + datasource_name=datasource_node_data.get("datasource_name"), + tenant_id=pipeline.tenant_id, + datasource_type=DatasourceProviderType(datasource_type), + ) + datasource_provider_service = DatasourceProviderService() + credentials = datasource_provider_service.get_real_datasource_credentials( + tenant_id=pipeline.tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + if credentials: + datasource_runtime.runtime.credentials = credentials[0].get("credentials") + match datasource_type: + case DatasourceProviderType.ONLINE_DOCUMENT: + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[DatasourceMessage, None, None] = ( + datasource_runtime.get_online_document_page_content( + user_id=account.id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=user_inputs.get("workspace_id"), + page_id=user_inputs.get("page_id"), + type=user_inputs.get("type"), + ), + provider_type=datasource_type, + ) + ) + try: + variables: dict[str, Any] = {} + for message in online_document_result: + if message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + else: + variables[variable_name] = variable_value + return variables + except Exception as e: + logger.exception("Error during get online document content.") + raise RuntimeError(str(e)) + #TODO Online Drive + case _: + raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") + except Exception as e: + logger.exception("Error in run_datasource_node_preview.") + raise RuntimeError(str(e)) def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] @@ -755,24 +807,77 @@ class RagPipelineService: return workflow - def get_first_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: + def get_published_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + """ + Get second step parameters of rag pipeline + """ + + workflow = self.get_published_workflow(pipeline=pipeline) + if not workflow: + raise ValueError("Workflow not initialized") + + # get second step node + rag_pipeline_variables = workflow.rag_pipeline_variables + if not rag_pipeline_variables: + return [] + + # get datasource provider + datasource_provider_variables = [ + item + for item in rag_pipeline_variables + if 
item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] + return datasource_provider_variables + + def get_published_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get first step parameters of rag pipeline """ - workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) - if not workflow: + published_workflow = self.get_published_workflow(pipeline=pipeline) + if not published_workflow: raise ValueError("Workflow not initialized") + # get second step node datasource_node_data = None - datasource_nodes = workflow.graph_dict.get("nodes", []) + datasource_nodes = published_workflow.graph_dict.get("nodes", []) for datasource_node in datasource_nodes: if datasource_node.get("id") == node_id: datasource_node_data = datasource_node.get("data", {}) break if not datasource_node_data: raise ValueError("Datasource node data not found") - variables = workflow.rag_pipeline_variables + variables = datasource_node_data.get("variables", {}) + if variables: + variables_map = {item["variable"]: item for item in variables} + else: + return [] + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + user_input_variables = [] + for key, value in datasource_parameters.items(): + if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): + user_input_variables.append(variables_map.get(key, {})) + return user_input_variables + + def get_draft_first_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: + """ + Get first step parameters of rag pipeline + """ + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + if not draft_workflow: + raise ValueError("Workflow not initialized") + + # get second step node + datasource_node_data = None + datasource_nodes = draft_workflow.graph_dict.get("nodes", []) + for datasource_node in datasource_nodes: + if datasource_node.get("id") == node_id: + datasource_node_data = datasource_node.get("data", {}) + break + if not datasource_node_data: + raise ValueError("Datasource node data not found") + variables = datasource_node_data.get("variables", {}) if variables: variables_map = {item["variable"]: item for item in variables} else: @@ -781,21 +886,16 @@ class RagPipelineService: user_input_variables = [] for key, value in datasource_parameters.items(): - if value.get("value") and isinstance(value.get("value"), str): - pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" - match = re.match(pattern, value["value"]) - if match: - full_path = match.group(1) - last_part = full_path.split('.')[-1] - user_input_variables.append(variables_map.get(last_part, {})) + if not re.match(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}", value["value"]): + user_input_variables.append(variables_map.get(key, {})) return user_input_variables - def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]: + def get_draft_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list[dict]: """ Get second step parameters of rag pipeline """ - workflow = self.get_draft_workflow(pipeline=pipeline) if is_draft else self.get_published_workflow(pipeline=pipeline) + workflow = self.get_draft_workflow(pipeline=pipeline) if not workflow: raise ValueError("Workflow not initialized") @@ -803,32 +903,13 @@ class RagPipelineService: rag_pipeline_variables = 
workflow.rag_pipeline_variables if not rag_pipeline_variables: return [] - variables_map = {item["variable"]: item for item in rag_pipeline_variables} - # get datasource node data - datasource_node_data = None - datasource_nodes = workflow.graph_dict.get("nodes", []) - for datasource_node in datasource_nodes: - if datasource_node.get("id") == node_id: - datasource_node_data = datasource_node.get("data", {}) - break - if datasource_node_data: - datasource_parameters = datasource_node_data.get("datasource_parameters", {}) - - for key, value in datasource_parameters.items(): - if value.get("value") and isinstance(value.get("value"), str): - pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" - match = re.match(pattern, value["value"]) - if match: - full_path = match.group(1) - last_part = full_path.split('.')[-1] - variables_map.pop(last_part) - all_second_step_variables = list(variables_map.values()) + # get datasource provider datasource_provider_variables = [ - item - for item in all_second_step_variables - if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" - ] + item + for item in rag_pipeline_variables + if item.get("belong_to_node_id") == node_id or item.get("belong_to_node_id") == "shared" + ] return datasource_provider_variables def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: @@ -950,16 +1031,6 @@ class RagPipelineService: if not dataset: raise ValueError("Dataset not found") - # check template name is exist - template_name = args.get("name") - if template_name: - template = db.session.query(PipelineCustomizedTemplate).filter( - PipelineCustomizedTemplate.name == template_name, - PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, - ).first() - if template: - raise ValueError("Template name is already exists") - max_position = ( db.session.query(func.max(PipelineCustomizedTemplate.position)) .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 7d5d4cb52d..e3354ade13 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -43,10 +43,10 @@ const AppDetailLayout: FC = (props) => { const media = useBreakpoints() const isMobile = media === MediaType.mobile const { isCurrentWorkspaceEditor, isLoadingCurrentWorkspace } = useAppContext() - const { appDetail, setAppDetail, setAppSiderbarExpand } = useStore(useShallow(state => ({ + const { appDetail, setAppDetail, setAppSidebarExpand } = useStore(useShallow(state => ({ appDetail: state.appDetail, setAppDetail: state.setAppDetail, - setAppSiderbarExpand: state.setAppSiderbarExpand, + setAppSidebarExpand: state.setAppSidebarExpand, }))) const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) const [appDetailRes, setAppDetailRes] = useState(null) @@ -57,8 +57,8 @@ const AppDetailLayout: FC = (props) => { selectedIcon: NavIcon }>>([]) - const getNavigations = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { - const navs = [ + const getNavigationConfig = useCallback((appId: string, isCurrentWorkspaceEditor: boolean, mode: string) => { + const navConfig = [ ...(isCurrentWorkspaceEditor ? 
[{ name: t('common.appMenus.promptEng'), @@ -92,8 +92,8 @@ const AppDetailLayout: FC = (props) => { selectedIcon: RiDashboard2Fill, }, ] - return navs - }, []) + return navConfig + }, [t]) useDocumentTitle(appDetail?.name || t('common.menus.appDetail')) @@ -101,10 +101,10 @@ const AppDetailLayout: FC = (props) => { if (appDetail) { const localeMode = localStorage.getItem('app-detail-collapse-or-expand') || 'expand' const mode = isMobile ? 'collapse' : 'expand' - setAppSiderbarExpand(isMobile ? mode : localeMode) + setAppSidebarExpand(isMobile ? mode : localeMode) // TODO: consider screen size and mode // if ((appDetail.mode === 'advanced-chat' || appDetail.mode === 'workflow') && (pathname).endsWith('workflow')) - // setAppSiderbarExpand('collapse') + // setAppSidebarExpand('collapse') } // eslint-disable-next-line react-hooks/exhaustive-deps }, [appDetail, isMobile]) @@ -141,7 +141,7 @@ const AppDetailLayout: FC = (props) => { } else { setAppDetail({ ...res, enable_sso: false }) - setNavigation(getNavigations(appId, isCurrentWorkspaceEditor, res.mode)) + setNavigation(getNavigationConfig(appId, isCurrentWorkspaceEditor, res.mode)) } // eslint-disable-next-line react-hooks/exhaustive-deps }, [appDetailRes, isCurrentWorkspaceEditor, isLoadingAppDetail, isLoadingCurrentWorkspace]) @@ -161,7 +161,9 @@ const AppDetailLayout: FC = (props) => { return (
{appDetail && ( - + )}
{children} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create-from-pipeline/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create-from-pipeline/page.tsx new file mode 100644 index 0000000000..9ce86bbef4 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create-from-pipeline/page.tsx @@ -0,0 +1,10 @@ +import React from 'react' +import CreateFromPipeline from '@/app/components/datasets/documents/create-from-pipeline' + +const CreateFromPipelinePage = async () => { + return ( + + ) +} + +export default CreateFromPipelinePage diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index fb3a9087ca..e0436d6f5c 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -2,10 +2,10 @@ import type { FC } from 'react' import React, { useEffect, useMemo } from 'react' import { usePathname } from 'next/navigation' -import useSWR from 'swr' import { useTranslation } from 'react-i18next' -import { useBoolean } from 'ahooks' +import type { RemixiconComponentType } from '@remixicon/react' import { + RiAttachmentLine, RiEqualizer2Fill, RiEqualizer2Line, RiFileTextFill, @@ -13,12 +13,8 @@ import { RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' -import { - PaperClipIcon, -} from '@heroicons/react/24/outline' -import { RiApps2AddLine, RiBookOpenLine, RiInformation2Line } from '@remixicon/react' +import { RiInformation2Line } from '@remixicon/react' import classNames from '@/utils/classnames' -import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets' import type { RelatedAppResponse } from '@/models/datasets' import AppSideBar from '@/app/components/app-sidebar' import Loading from '@/app/components/base/loading' @@ -30,6 +26,10 @@ import { useDocLink } from '@/context/i18n' import { useAppContext } from '@/context/app-context' import Tooltip from '@/app/components/base/tooltip' import LinkedAppsPanel from '@/app/components/base/linked-apps-panel' +import { PipelineFill, PipelineLine } from '@/app/components/base/icons/src/vender/pipeline' +import { Divider } from '@/app/components/base/icons/src/vender/knowledge' +import NoLinkedAppsPanel from '@/app/components/datasets/no-linked-apps-panel' +import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' import useDocumentTitle from '@/hooks/use-document-title' export type IAppDetailLayoutProps = { @@ -38,81 +38,72 @@ export type IAppDetailLayoutProps = { } type IExtraInfoProps = { - isMobile: boolean relatedApps?: RelatedAppResponse + documentCount?: number expand: boolean } -const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { - const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) +const ExtraInfo = React.memo(({ + relatedApps, + documentCount, + expand, +}: IExtraInfoProps) => { const { t } = useTranslation() const docLink = useDocLink() const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 const relatedAppsTotal = relatedApps?.data?.length || 0 - useEffect(() => { - setShowTips(!isMobile) - }, [isMobile, setShowTips]) - - return
- {hasRelatedApps && ( - <> - {!isMobile && ( - - } - > -
- {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} - + return ( + <> + {!expand && ( +
+
+
+ {documentCount ?? '--'}
- - )} - - {isMobile &&
- {relatedAppsTotal || '--'} - -
} - - )} - {!hasRelatedApps && !expand && ( - -
- +
+ {t('common.datasetMenus.documents')}
-
{t('common.datasetMenus.emptyTip')}
- - - {t('common.datasetMenus.viewDoc')} -
- } - > -
- {t('common.datasetMenus.noRelatedApp')} - +
+ +
+
+
+ {relatedAppsTotal ?? '--'} +
+ + ) : + } + > +
+ {t('common.datasetMenus.relatedApp')} + +
+
+
-
- )} -
-} + )} + + {expand && ( +
+ {relatedAppsTotal ?? '--'} + +
+ )} + + ) +}) const DatasetDetailLayout: FC = (props) => { const { @@ -120,70 +111,98 @@ const DatasetDetailLayout: FC = (props) => { params: { datasetId }, } = props const pathname = usePathname() - const hideSideBar = /documents\/create$/.test(pathname) + const hideSideBar = pathname.endsWith('documents/create') || pathname.endsWith('documents/create-from-pipeline') const { t } = useTranslation() const { isCurrentWorkspaceDatasetOperator } = useAppContext() const media = useBreakpoints() const isMobile = media === MediaType.mobile - const { data: datasetRes, error, mutate: mutateDatasetRes } = useSWR({ - url: 'fetchDatasetDetail', - datasetId, - }, apiParams => fetchDatasetDetail(apiParams.datasetId)) + const { data: datasetRes, error, refetch: mutateDatasetRes } = useDatasetDetail(datasetId) - const { data: relatedApps } = useSWR({ - action: 'fetchDatasetRelatedApps', - datasetId, - }, apiParams => fetchDatasetRelatedApps(apiParams.datasetId)) + const { data: relatedApps } = useDatasetRelatedApps(datasetId) + + const isButtonDisabledWithPipeline = useMemo(() => { + if (!datasetRes) + return true + if (datasetRes.provider === 'external') + return false + if (datasetRes.runtime_mode === 'general') + return false + return !datasetRes.is_published + }, [datasetRes]) const navigation = useMemo(() => { const baseNavigation = [ - { name: t('common.datasetMenus.hitTesting'), href: `/datasets/${datasetId}/hitTesting`, icon: RiFocus2Line, selectedIcon: RiFocus2Fill }, - { name: t('common.datasetMenus.settings'), href: `/datasets/${datasetId}/settings`, icon: RiEqualizer2Line, selectedIcon: RiEqualizer2Fill }, + { + name: t('common.datasetMenus.hitTesting'), + href: `/datasets/${datasetId}/hitTesting`, + icon: RiFocus2Line, + selectedIcon: RiFocus2Fill, + disabled: isButtonDisabledWithPipeline, + }, + { + name: t('common.datasetMenus.settings'), + href: `/datasets/${datasetId}/settings`, + icon: RiEqualizer2Line, + selectedIcon: RiEqualizer2Fill, + disabled: false, + }, ] if (datasetRes?.provider !== 'external') { + if (datasetRes?.runtime_mode === 'rag_pipeline') { + baseNavigation.unshift({ + name: t('common.datasetMenus.pipeline'), + href: `/datasets/${datasetId}/pipeline`, + icon: PipelineLine as RemixiconComponentType, + selectedIcon: PipelineFill as RemixiconComponentType, + disabled: false, + }) + } baseNavigation.unshift({ name: t('common.datasetMenus.documents'), href: `/datasets/${datasetId}/documents`, icon: RiFileTextLine, selectedIcon: RiFileTextFill, + disabled: isButtonDisabledWithPipeline, }) } + return baseNavigation - }, [datasetRes?.provider, datasetId, t]) + }, [t, datasetId, isButtonDisabledWithPipeline, datasetRes?.provider, datasetRes?.runtime_mode]) useDocumentTitle(datasetRes?.name || t('common.menus.datasets')) - const setAppSiderbarExpand = useStore(state => state.setAppSiderbarExpand) + const setAppSidebarExpand = useStore(state => state.setAppSidebarExpand) useEffect(() => { const localeMode = localStorage.getItem('app-detail-collapse-or-expand') || 'expand' const mode = isMobile ? 'collapse' : 'expand' - setAppSiderbarExpand(isMobile ? mode : localeMode) - }, [isMobile, setAppSiderbarExpand]) + setAppSidebarExpand(isMobile ? mode : localeMode) + }, [isMobile, setAppSidebarExpand]) if (!datasetRes && !error) return return (
- {!hideSideBar && : undefined} - iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'} - />} mutateDatasetRes(), + mutateDatasetRes, }}> + {!hideSideBar && ( + + : undefined + } + iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'} + /> + )}
{children}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/pipeline/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/pipeline/page.tsx new file mode 100644 index 0000000000..9a18021cc0 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/pipeline/page.tsx @@ -0,0 +1,11 @@ +'use client' +import RagPipeline from '@/app/components/rag-pipeline' + +const PipelinePage = () => { + return ( +
+ +
+ ) +} +export default PipelinePage diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index d9a196d854..164c2dc7ba 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -8,8 +8,8 @@ const Settings = async () => { return (
-
-
{t('title')}
+
+
{t('title')}
{t('desc')}
diff --git a/web/app/(commonLayout)/datasets/Datasets.tsx b/web/app/(commonLayout)/datasets/Datasets.tsx deleted file mode 100644 index 2d4848e92e..0000000000 --- a/web/app/(commonLayout)/datasets/Datasets.tsx +++ /dev/null @@ -1,96 +0,0 @@ -'use client' - -import { useCallback, useEffect, useRef } from 'react' -import useSWRInfinite from 'swr/infinite' -import { debounce } from 'lodash-es' -import NewDatasetCard from './NewDatasetCard' -import DatasetCard from './DatasetCard' -import type { DataSetListResponse, FetchDatasetsParams } from '@/models/datasets' -import { fetchDatasets } from '@/service/datasets' -import { useAppContext } from '@/context/app-context' -import { useTranslation } from 'react-i18next' - -const getKey = ( - pageIndex: number, - previousPageData: DataSetListResponse, - tags: string[], - keyword: string, - includeAll: boolean, -) => { - if (!pageIndex || previousPageData.has_more) { - const params: FetchDatasetsParams = { - url: 'datasets', - params: { - page: pageIndex + 1, - limit: 30, - include_all: includeAll, - }, - } - if (tags.length) - params.params.tag_ids = tags - if (keyword) - params.params.keyword = keyword - return params - } - return null -} - -type Props = { - containerRef: React.RefObject - tags: string[] - keywords: string - includeAll: boolean -} - -const Datasets = ({ - containerRef, - tags, - keywords, - includeAll, -}: Props) => { - const { t } = useTranslation() - const { isCurrentWorkspaceEditor } = useAppContext() - const { data, isLoading, setSize, mutate } = useSWRInfinite( - (pageIndex: number, previousPageData: DataSetListResponse) => getKey(pageIndex, previousPageData, tags, keywords, includeAll), - fetchDatasets, - { revalidateFirstPage: false, revalidateAll: true }, - ) - const loadingStateRef = useRef(false) - const anchorRef = useRef(null) - - useEffect(() => { - loadingStateRef.current = isLoading - }, [isLoading, t]) - - const onScroll = useCallback( - debounce(() => { - if (!loadingStateRef.current && containerRef.current && anchorRef.current) { - const { scrollTop, clientHeight } = containerRef.current - const anchorOffset = anchorRef.current.offsetTop - if (anchorOffset - scrollTop - clientHeight < 100) - setSize(size => size + 1) - } - }, 50), - [setSize], - ) - - useEffect(() => { - const currentContainer = containerRef.current - currentContainer?.addEventListener('scroll', onScroll) - return () => { - currentContainer?.removeEventListener('scroll', onScroll) - onScroll.cancel() - } - }, [containerRef, onScroll]) - - return ( - - ) -} - -export default Datasets diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx deleted file mode 100644 index f3532f398d..0000000000 --- a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx +++ /dev/null @@ -1,42 +0,0 @@ -'use client' -import { useTranslation } from 'react-i18next' -import { basePath } from '@/utils/var' -import { - RiAddLine, - RiArrowRightLine, -} from '@remixicon/react' -import Link from 'next/link' - -type CreateAppCardProps = { - ref?: React.Ref -} - -const CreateAppCard = ({ ref }: CreateAppCardProps) => { - const { t } = useTranslation() - - return ( -
- -
-
- -
-
{t('dataset.createDataset')}
-
- -
{t('dataset.createDatasetIntro')}
- -
{t('dataset.connectDataset')}
- - -
- ) -} - -CreateAppCard.displayName = 'CreateAppCard' - -export default CreateAppCard diff --git a/web/app/(commonLayout)/datasets/create-from-pipeline/page.tsx b/web/app/(commonLayout)/datasets/create-from-pipeline/page.tsx new file mode 100644 index 0000000000..72f5ecdfd9 --- /dev/null +++ b/web/app/(commonLayout)/datasets/create-from-pipeline/page.tsx @@ -0,0 +1,10 @@ +import React from 'react' +import CreateFromPipeline from '@/app/components/datasets/create-from-pipeline' + +const DatasetCreation = async () => { + return ( + + ) +} + +export default DatasetCreation diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx index 60a542f0a2..8388b69468 100644 --- a/web/app/(commonLayout)/datasets/page.tsx +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -1,12 +1,7 @@ -'use client' -import { useTranslation } from 'react-i18next' -import Container from './Container' -import useDocumentTitle from '@/hooks/use-document-title' +import List from '../../components/datasets/list' -const AppList = () => { - const { t } = useTranslation() - useDocumentTitle(t('common.menus.datasets')) - return +const DatasetList = async () => { + return } -export default AppList +export default DatasetList diff --git a/web/app/(commonLayout)/datasets/store.ts b/web/app/(commonLayout)/datasets/store.ts deleted file mode 100644 index 40b7b15594..0000000000 --- a/web/app/(commonLayout)/datasets/store.ts +++ /dev/null @@ -1,11 +0,0 @@ -import { create } from 'zustand' - -type DatasetStore = { - showExternalApiPanel: boolean - setShowExternalApiPanel: (show: boolean) => void -} - -export const useDatasetStore = create(set => ({ - showExternalApiPanel: false, - setShowExternalApiPanel: show => set({ showExternalApiPanel: show }), -})) diff --git a/web/app/components/app-sidebar/dataset-info.tsx b/web/app/components/app-sidebar/dataset-info.tsx index 73740133ce..3db8789722 100644 --- a/web/app/components/app-sidebar/dataset-info.tsx +++ b/web/app/components/app-sidebar/dataset-info.tsx @@ -3,40 +3,89 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import AppIcon from '../base/app-icon' - -const DatasetSvg = - - +import Effect from '../base/effect' +import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import type { DataSet } from '@/models/datasets' +import { DOC_FORM_ICON_WITH_BG, DOC_FORM_TEXT } from '@/models/datasets' +import { useKnowledge } from '@/hooks/use-knowledge' +import Badge from '../base/badge' +import cn from '@/utils/classnames' type Props = { - isExternal?: boolean - name: string - description: string expand: boolean extraInfo?: React.ReactNode } const DatasetInfo: FC = ({ - name, - description, - isExternal, expand, extraInfo, }) => { const { t } = useTranslation() + const dataset = useDatasetDetailContextWithSelector(state => state.dataset) as DataSet + const iconInfo = dataset.icon_info || { + icon: '📙', + icon_type: 'emoji', + icon_background: '#FFF4ED', + icon_url: '', + } + const isExternalProvider = dataset.provider === 'external' + const { formatIndexingTechniqueAndMethod } = useKnowledge() + const chunkingModeIcon = dataset.doc_form ? DOC_FORM_ICON_WITH_BG[dataset.doc_form] : React.Fragment + const Icon = isExternalProvider ? DOC_FORM_ICON_WITH_BG.external : chunkingModeIcon + return ( -
-
- -
+
{expand && ( -
-
- {name} + <> + +
+
+ + {(dataset.doc_form || isExternalProvider) && ( +
+ +
+ )} +
+ <> +
+
+ {dataset.name} +
+
+ {isExternalProvider && t('dataset.externalTag')} + {!isExternalProvider && dataset.doc_form && dataset.indexing_technique && ( +
+ {t(`dataset.chunkingMode.${DOC_FORM_TEXT[dataset.doc_form]}`)} + {formatIndexingTechniqueAndMethod(dataset.indexing_technique, dataset.retrieval_model_dict?.search_method)} +
+ )} +
+
+

+ {dataset.description} +

+
-
{isExternal ? t('dataset.externalTag') : t('dataset.localDocs')}
-
{description}
-
+ + )} + {!expand && ( + )} {extraInfo}
diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index b6bfc0e9ac..f90b7437d5 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -4,7 +4,6 @@ import { useShallow } from 'zustand/react/shallow' import { RiLayoutLeft2Line, RiLayoutRight2Line } from '@remixicon/react' import NavLink from './navLink' import type { NavIcon } from './navLink' -import AppBasic from './basic' import AppInfo from './app-info' import DatasetInfo from './dataset-info' import AppSidebarDropdown from './app-sidebar-dropdown' @@ -15,31 +14,31 @@ import cn from '@/utils/classnames' export type IAppDetailNavProps = { iconType?: 'app' | 'dataset' | 'notion' - title: string - desc: string - isExternal?: boolean - icon: string - icon_background: string | null navigation: Array<{ name: string href: string icon: NavIcon selectedIcon: NavIcon + disabled?: boolean }> extraInfo?: (modeState: string) => React.ReactNode } -const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => { - const { appSidebarExpand, setAppSiderbarExpand } = useAppStore(useShallow(state => ({ +const AppDetailNav = ({ + navigation, + extraInfo, + iconType = 'app', +}: IAppDetailNavProps) => { + const { appSidebarExpand, setAppSidebarExpand } = useAppStore(useShallow(state => ({ appSidebarExpand: state.appSidebarExpand, - setAppSiderbarExpand: state.setAppSiderbarExpand, + setAppSidebarExpand: state.setAppSidebarExpand, }))) const media = useBreakpoints() const isMobile = media === MediaType.mobile const expand = appSidebarExpand === 'expand' const handleToggle = (state: string) => { - setAppSiderbarExpand(state === 'expand' ? 'collapse' : 'expand') + setAppSidebarExpand(state === 'expand' ? 'collapse' : 'expand') } // // Check if the current path is a workflow canvas & fullscreen @@ -57,9 +56,9 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati useEffect(() => { if (appSidebarExpand) { localStorage.setItem('app-detail-collapse-or-expand', appSidebarExpand) - setAppSiderbarExpand(appSidebarExpand) + setAppSidebarExpand(appSidebarExpand) } - }, [appSidebarExpand, setAppSiderbarExpand]) + }, [appSidebarExpand, setAppSidebarExpand]) if (inWorkflowCanvas && hideHeader) { return ( @@ -85,26 +84,12 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati {iconType === 'app' && ( )} - {iconType === 'dataset' && ( + {iconType !== 'app' && ( )} - {!['app', 'dataset'].includes(iconType) && ( - - )}
@@ -117,7 +102,14 @@ const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigati > {navigation.map((item, index) => { return ( - + ) })} diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx index 295b553b04..a69f0bd6aa 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/navLink.tsx @@ -6,10 +6,10 @@ import classNames from '@/utils/classnames' import type { RemixiconComponentType } from '@remixicon/react' export type NavIcon = React.ComponentType< -React.PropsWithoutRef> & { - title?: string | undefined - titleId?: string | undefined -}> | RemixiconComponentType + React.PropsWithoutRef> & { + title?: string | undefined + titleId?: string | undefined + }> | RemixiconComponentType export type NavLinkProps = { name: string @@ -19,6 +19,7 @@ export type NavLinkProps = { normal: NavIcon } mode?: string + disabled?: boolean } export default function NavLink({ @@ -26,6 +27,7 @@ export default function NavLink({ href, iconMap, mode = 'expand', + disabled = false, }: NavLinkProps) { const segment = useSelectedLayoutSegment() const formattedSegment = (() => { @@ -39,13 +41,38 @@ export default function NavLink({ const isActive = href.toLowerCase().split('/')?.pop() === formattedSegment const NavIcon = isActive ? iconMap.selected : iconMap.normal + if (disabled) { + return ( + + ) + } + return ( = ({ const [isDeleting, setIsDeleting] = useState(false) + const iconInfo = config.icon_info || { + icon: '📙', + icon_type: 'emoji', + icon_background: '#FFF4ED', + icon_url: '', + } + return (
- { - config.data_source_type === DataSourceType.FILE && ( -
- -
- ) - } - { - config.data_source_type === DataSourceType.NOTION && ( -
- -
- ) - } - { - config.data_source_type === DataSourceType.WEB && ( -
- -
- ) - } +
{config.name}
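
Note on the card-item hunk above: it collapses the three per-type icon branches (FILE / NOTION / WEB) into a single AppIcon driven by config.icon_info, with the 📙 emoji fallback defined earlier in the component. A minimal sketch of that fallback pattern; the type names here are illustrative, not taken from the codebase:

// Sketch only: the icon_info shape is inferred from the fallback object in
// the diff; DataSourceIconInfo and DataSourceCardConfig are assumed names.
type DataSourceIconInfo = {
  icon: string
  icon_type: 'emoji' | 'image'
  icon_background: string
  icon_url: string
}

type DataSourceCardConfig = {
  name: string
  icon_info?: DataSourceIconInfo
}

const resolveIconInfo = (config: DataSourceCardConfig): DataSourceIconInfo =>
  config.icon_info || {
    icon: '📙',
    icon_type: 'emoji',
    icon_background: '#FFF4ED',
    icon_url: '',
  }
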
diff --git a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx index 5f0ad94d86..ebfa3b1e12 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx @@ -21,10 +21,12 @@ type Value = { type WeightedScoreProps = { value: Value onChange: (value: Value) => void + readonly?: boolean } const WeightedScore = ({ value, onChange = noop, + readonly = false, }: WeightedScoreProps) => { const { t } = useTranslation() @@ -37,8 +39,9 @@ const WeightedScore = ({ min={0} step={0.1} value={value.value[0]} - onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })} + onChange={v => !readonly && onChange({ value: [v, (10 - v * 10) / 10] })} trackClassName='weightedScoreSliderTrack' + disabled={readonly} />
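
Note on the weighted-score hunk above: it adds an opt-in readonly mode, disabling the slider and short-circuiting onChange. A hypothetical caller; the wrapper component, the canEdit flag, and the initial weights are assumptions, only the value/onChange/readonly props come from the diff:

import { useState } from 'react'
import WeightedScore from './weighted-score'

// Illustrative wrapper: renders the weights read-only unless the user may edit.
const RetrievalWeightSetting = ({ canEdit }: { canEdit: boolean }) => {
  // Semantic + keyword weights; the two values always sum to 1.
  const [weights, setWeights] = useState({ value: [0.7, 0.3] })

  return (
    <WeightedScore
      value={weights}
      onChange={setWeights}
      readonly={!canEdit}
    />
  )
}

export default RetrievalWeightSetting
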
diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index ffdb714f08..99dc32dfa2 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -5,7 +5,6 @@ import { useGetState, useInfiniteScroll } from 'ahooks' import { useTranslation } from 'react-i18next' import Link from 'next/link' import produce from 'immer' -import TypeIcon from '../type-icon' import Modal from '@/app/components/base/modal' import type { DataSet } from '@/models/datasets' import Button from '@/app/components/base/button' @@ -15,6 +14,7 @@ import Badge from '@/app/components/base/badge' import { useKnowledge } from '@/hooks/use-knowledge' import cn from '@/utils/classnames' import { basePath } from '@/utils/var' +import AppIcon from '@/app/components/base/app-icon' export type ISelectDataSetProps = { isShow: boolean @@ -91,6 +91,7 @@ const SelectDataSet: FC = ({ const handleSelect = () => { onSelect(selected) } + return ( = ({ >
- +
{item.name}
{!item.embedding_available && ( diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 9835481ae0..62f1010b54 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -6,7 +6,7 @@ import { isEqual } from 'lodash-es' import { RiCloseLine } from '@remixicon/react' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' import cn from '@/utils/classnames' -import IndexMethodRadio from '@/app/components/datasets/settings/index-method-radio' +import IndexMethod from '@/app/components/datasets/settings/index-method' import Divider from '@/app/components/base/divider' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' @@ -31,6 +31,7 @@ import { import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { fetchMembers } from '@/service/common' import type { Member } from '@/models/common' +import { IndexingType } from '@/app/components/datasets/create/step-two' import { useDocLink } from '@/context/i18n' type SettingsModalProps = { @@ -55,8 +56,6 @@ const SettingsModal: FC = ({ const { data: embeddingsModelList } = useModelList(ModelTypeEnum.textEmbedding) const { modelList: rerankModelList, - defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { t } = useTranslation() const docLink = useDocLink() @@ -75,6 +74,7 @@ const SettingsModal: FC = ({ const [indexMethod, setIndexMethod] = useState(currentDataset.indexing_technique) const [retrievalConfig, setRetrievalConfig] = useState(localeCurrentDataset?.retrieval_model_dict as RetrievalConfig) + const [keywordNumber, setKeywordNumber] = useState(currentDataset.keyword_number ?? 10) const handleValueChange = (type: string, value: string) => { setLocaleCurrentDataset({ ...localeCurrentDataset, [type]: value }) @@ -126,6 +126,7 @@ const SettingsModal: FC = ({ description, permission, indexing_technique: indexMethod, + keyword_number: keywordNumber, retrieval_model: { ...retrievalConfig, score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0, @@ -247,17 +248,18 @@ const SettingsModal: FC = ({
{t('datasetSettings.form.indexMethod')}
- setIndexMethod(v!)} - docForm={currentDataset.doc_form} + onChange={setIndexMethod} currentValue={currentDataset.indexing_technique} + keywordNumber={keywordNumber} + onKeywordNumberChange={setKeywordNumber} />
)} - {indexMethod === 'high_quality' && ( + {indexMethod === IndexingType.QUALIFIED && (
{t('datasetSettings.form.embeddingModel')}
@@ -336,7 +338,7 @@ const SettingsModal: FC = ({
- {indexMethod === 'high_quality' + {indexMethod === IndexingType.QUALIFIED ? ( { const { t } = useTranslation() const { notify } = useContext(ToastContext) - const { appDetail, showAppConfigureFeaturesModal, setAppSiderbarExpand, setShowAppConfigureFeaturesModal } = useAppStore(useShallow(state => ({ + const { appDetail, showAppConfigureFeaturesModal, setAppSidebarExpand, setShowAppConfigureFeaturesModal } = useAppStore(useShallow(state => ({ appDetail: state.appDetail, - setAppSiderbarExpand: state.setAppSiderbarExpand, + setAppSidebarExpand: state.setAppSidebarExpand, showAppConfigureFeaturesModal: state.showAppConfigureFeaturesModal, setShowAppConfigureFeaturesModal: state.setShowAppConfigureFeaturesModal, }))) @@ -823,7 +823,7 @@ const Configuration: FC = () => { { id: `${Date.now()}-no-repeat`, model: '', provider: '', parameters: {} }, ], ) - setAppSiderbarExpand('collapse') + setAppSidebarExpand('collapse') } if (isLoading) { diff --git a/web/app/components/app/store.ts b/web/app/components/app/store.ts index 5f02f92f0d..a90d560ac7 100644 --- a/web/app/components/app/store.ts +++ b/web/app/components/app/store.ts @@ -15,7 +15,7 @@ type State = { type Action = { setAppDetail: (appDetail?: App & Partial) => void - setAppSiderbarExpand: (state: string) => void + setAppSidebarExpand: (state: string) => void setCurrentLogItem: (item?: IChatItem) => void setCurrentLogModalActiveTab: (tab: string) => void setShowPromptLogModal: (showPromptLogModal: boolean) => void @@ -28,7 +28,7 @@ export const useStore = create(set => ({ appDetail: undefined, setAppDetail: appDetail => set(() => ({ appDetail })), appSidebarExpand: '', - setAppSiderbarExpand: appSidebarExpand => set(() => ({ appSidebarExpand })), + setAppSidebarExpand: appSidebarExpand => set(() => ({ appSidebarExpand })), currentLogItem: undefined, currentLogModalActiveTab: 'DETAIL', setCurrentLogItem: currentLogItem => set(() => ({ currentLogItem })), diff --git a/web/app/components/app/workflow-log/detail.tsx b/web/app/components/app/workflow-log/detail.tsx index dc3eb89a2a..bb5b268d5d 100644 --- a/web/app/components/app/workflow-log/detail.tsx +++ b/web/app/components/app/workflow-log/detail.tsx @@ -3,6 +3,7 @@ import type { FC } from 'react' import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import Run from '@/app/components/workflow/run' +import { useStore } from '@/app/components/app/store' type ILogDetail = { runID: string @@ -11,6 +12,7 @@ type ILogDetail = { const DetailPanel: FC = ({ runID, onClose }) => { const { t } = useTranslation() + const appDetail = useStore(state => state.appDetail) return (
@@ -18,7 +20,10 @@ const DetailPanel: FC = ({ runID, onClose }) => {

{t('appLog.runDetail.workflowTitle')}

- +
) } diff --git a/web/app/components/base/action-button/index.spec.tsx b/web/app/components/base/action-button/index.spec.tsx new file mode 100644 index 0000000000..76c8eebda0 --- /dev/null +++ b/web/app/components/base/action-button/index.spec.tsx @@ -0,0 +1,76 @@ +import { render, screen } from '@testing-library/react' +import { ActionButton, ActionButtonState } from './index' + +describe('ActionButton', () => { + test('renders button with default props', () => { + render(Click me) + const button = screen.getByRole('button', { name: 'Click me' }) + expect(button).toBeInTheDocument() + expect(button.classList.contains('action-btn')).toBe(true) + expect(button.classList.contains('action-btn-m')).toBe(true) + }) + + test('renders button with xs size', () => { + render(Small Button) + const button = screen.getByRole('button', { name: 'Small Button' }) + expect(button.classList.contains('action-btn-xs')).toBe(true) + }) + + test('renders button with l size', () => { + render(Large Button) + const button = screen.getByRole('button', { name: 'Large Button' }) + expect(button.classList.contains('action-btn-l')).toBe(true) + }) + + test('renders button with xl size', () => { + render(Extra Large Button) + const button = screen.getByRole('button', { name: 'Extra Large Button' }) + expect(button.classList.contains('action-btn-xl')).toBe(true) + }) + + test('applies correct state classes', () => { + const { rerender } = render( + Destructive, + ) + let button = screen.getByRole('button', { name: 'Destructive' }) + expect(button.classList.contains('action-btn-destructive')).toBe(true) + + rerender(Active) + button = screen.getByRole('button', { name: 'Active' }) + expect(button.classList.contains('action-btn-active')).toBe(true) + + rerender(Disabled) + button = screen.getByRole('button', { name: 'Disabled' }) + expect(button.classList.contains('action-btn-disabled')).toBe(true) + + rerender(Hover) + button = screen.getByRole('button', { name: 'Hover' }) + expect(button.classList.contains('action-btn-hover')).toBe(true) + }) + + test('applies custom className', () => { + render(Custom Class) + const button = screen.getByRole('button', { name: 'Custom Class' }) + expect(button.classList.contains('custom-class')).toBe(true) + }) + + test('applies custom style', () => { + render( + + Custom Style + , + ) + const button = screen.getByRole('button', { name: 'Custom Style' }) + expect(button).toHaveStyle({ + color: 'red', + backgroundColor: 'blue', + }) + }) + + test('forwards additional button props', () => { + render(Disabled Button) + const button = screen.getByRole('button', { name: 'Disabled Button' }) + expect(button).toBeDisabled() + expect(button).toHaveAttribute('data-testid', 'test-button') + }) +}) diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index 8e66cd38cf..a127c8acb7 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -5,7 +5,6 @@ import type { Area } from 'react-easy-crop' import Modal from '../modal' import Divider from '../divider' import Button from '../button' -import { ImagePlus } from '../icons/src/vender/line/images' import { useLocalFileUploader } from '../image-uploader/hooks' import EmojiPickerInner from '../emoji-picker/Inner' import type { OnImageInput } from './ImageInput' @@ -16,6 +15,7 @@ import type { AppIconType, ImageFile } from '@/types/app' import cn from '@/utils/classnames' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' 
import { noop } from 'lodash-es' +import { RiImageCircleAiLine } from '@remixicon/react' export type AppIconEmojiSelection = { type: 'emoji' @@ -46,7 +46,7 @@ const AppIconPicker: FC = ({ const tabs = [ { key: 'emoji', label: t('app.iconPicker.emoji'), icon: 🤖 }, - { key: 'image', label: t('app.iconPicker.image'), icon: }, + { key: 'image', label: t('app.iconPicker.image'), icon: }, ] const [activeTab, setActiveTab] = useState('emoji') @@ -119,10 +119,10 @@ const AppIconPicker: FC = ({ {tabs.map(tab => (
} />) + const innerIcon = screen.getByTestId('inner-icon') + expect(innerIcon).toBeInTheDocument() + }) + + it('applies size classes correctly', () => { + const { container: xsContainer } = render() + expect(xsContainer.firstChild).toHaveClass('w-4 h-4 rounded-[4px]') + + const { container: tinyContainer } = render() + expect(tinyContainer.firstChild).toHaveClass('w-6 h-6 rounded-md') + + const { container: smallContainer } = render() + expect(smallContainer.firstChild).toHaveClass('w-8 h-8 rounded-lg') + + const { container: mediumContainer } = render() + expect(mediumContainer.firstChild).toHaveClass('w-9 h-9 rounded-[10px]') + + const { container: largeContainer } = render() + expect(largeContainer.firstChild).toHaveClass('w-10 h-10 rounded-[10px]') + + const { container: xlContainer } = render() + expect(xlContainer.firstChild).toHaveClass('w-12 h-12 rounded-xl') + + const { container: xxlContainer } = render() + expect(xxlContainer.firstChild).toHaveClass('w-14 h-14 rounded-2xl') + }) + + it('applies rounded class when rounded=true', () => { + const { container } = render() + expect(container.firstChild).toHaveClass('rounded-full') + }) + + it('applies custom background color', () => { + const { container } = render() + expect(container.firstChild).toHaveStyle('background: #FF5500') + }) + + it('uses default background color when no background is provided for non-image icons', () => { + const { container } = render() + expect(container.firstChild).toHaveStyle('background: #FFEAD5') + }) + + it('does not apply background style for image icons', () => { + const { container } = render() + // Should not have the background style from the prop + expect(container.firstChild).not.toHaveStyle('background: #FF5500') + }) + + it('calls onClick handler when clicked', () => { + const handleClick = jest.fn() + const { container } = render() + fireEvent.click(container.firstChild!) 
+ + expect(handleClick).toHaveBeenCalledTimes(1) + }) + + it('applies custom className', () => { + const { container } = render() + expect(container.firstChild).toHaveClass('custom-class') + }) + + it('does not display edit icon when showEditIcon=false', () => { + render() + const editIcon = screen.queryByRole('svg') + expect(editIcon).not.toBeInTheDocument() + }) + + it('displays edit icon when showEditIcon=true and hovering', () => { + // Mock the useHover hook to return true for this test + require('ahooks').useHover.mockReturnValue(true) + + render() + const editIcon = document.querySelector('svg') + expect(editIcon).toBeInTheDocument() + }) + + it('does not display edit icon when showEditIcon=true but not hovering', () => { + // useHover returns false by default from our mock setup + render() + const editIcon = document.querySelector('svg') + expect(editIcon).not.toBeInTheDocument() + }) + + it('handles conditional isValidImageIcon check correctly', () => { + // Case 1: Valid image icon + const { rerender } = render( + , + ) + expect(screen.getByAltText('app icon')).toBeInTheDocument() + + // Case 2: Invalid - missing image URL + rerender() + expect(screen.queryByAltText('app icon')).not.toBeInTheDocument() + + // Case 3: Invalid - wrong icon type + rerender() + expect(screen.queryByAltText('app icon')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx index ac17af1988..f7eaa20917 100644 --- a/web/app/components/base/app-icon/index.tsx +++ b/web/app/components/base/app-icon/index.tsx @@ -1,11 +1,12 @@ 'use client' - -import type { FC } from 'react' +import React, { type FC, useRef } from 'react' import { init } from 'emoji-mart' import data from '@emoji-mart/data' import { cva } from 'class-variance-authority' import type { AppIconType } from '@/types/app' import classNames from '@/utils/classnames' +import { useHover } from 'ahooks' +import { RiEditLine } from '@remixicon/react' init({ data }) @@ -18,20 +19,21 @@ export type AppIconProps = { imageUrl?: string | null className?: string innerIcon?: React.ReactNode + showEditIcon?: boolean onClick?: () => void } const appIconVariants = cva( - 'flex items-center justify-center relative text-lg rounded-lg grow-0 shrink-0 overflow-hidden leading-none', + 'flex items-center justify-center relative grow-0 shrink-0 overflow-hidden leading-none border-[0.5px] border-divider-regular', { variants: { size: { - xs: 'w-4 h-4 text-xs', - tiny: 'w-6 h-6 text-base', - small: 'w-8 h-8 text-xl', - medium: 'w-9 h-9 text-[22px]', - large: 'w-10 h-10 text-[24px]', - xl: 'w-12 h-12 text-[28px]', - xxl: 'w-14 h-14 text-[32px]', + xs: 'w-4 h-4 text-xs rounded-[4px]', + tiny: 'w-6 h-6 text-base rounded-md', + small: 'w-8 h-8 text-xl rounded-lg', + medium: 'w-9 h-9 text-[22px] rounded-[10px]', + large: 'w-10 h-10 text-[24px] rounded-[10px]', + xl: 'w-12 h-12 text-[28px] rounded-xl', + xxl: 'w-14 h-14 text-[32px] rounded-2xl', }, rounded: { true: 'rounded-full', @@ -42,6 +44,46 @@ const appIconVariants = cva( rounded: false, }, }) +const EditIconWrapperVariants = cva( + 'absolute left-0 top-0 z-10 flex items-center justify-center bg-background-overlay-alt', + { + variants: { + size: { + xs: 'w-4 h-4 rounded-[4px]', + tiny: 'w-6 h-6 rounded-md', + small: 'w-8 h-8 rounded-lg', + medium: 'w-9 h-9 rounded-[10px]', + large: 'w-10 h-10 rounded-[10px]', + xl: 'w-12 h-12 rounded-xl', + xxl: 'w-14 h-14 rounded-2xl', + }, + rounded: { + true: 'rounded-full', + }, + }, + defaultVariants: { 
+ size: 'medium', + rounded: false, + }, + }) +const EditIconVariants = cva( + 'text-text-primary-on-surface', + { + variants: { + size: { + xs: 'size-3', + tiny: 'size-3.5', + small: 'size-5', + medium: 'size-[22px]', + large: 'size-6', + xl: 'size-7', + xxl: 'size-8', + }, + }, + defaultVariants: { + size: 'medium', + }, + }) const AppIcon: FC = ({ size = 'medium', rounded = false, @@ -52,20 +94,34 @@ const AppIcon: FC = ({ className, innerIcon, onClick, + showEditIcon = false, }) => { const isValidImageIcon = iconType === 'image' && imageUrl + const Icon = (icon && icon !== '') ? : + const wrapperRef = useRef(null) + const isHovering = useHover(wrapperRef) - return - {isValidImageIcon - - ? app icon - : (innerIcon || ((icon && icon !== '') ? : )) - } - + return ( + + { + isValidImageIcon + ? app icon + : (innerIcon || Icon) + } + { + showEditIcon && isHovering && ( +
+ +
+ ) + } +
+ ) } -export default AppIcon +export default React.memo(AppIcon) diff --git a/web/app/components/base/corner-label/index.tsx index 9e192ed753..0807ed4659 100644 --- a/web/app/components/base/corner-label/index.tsx +++ b/web/app/components/base/corner-label/index.tsx @@ -10,8 +10,8 @@ type CornerLabelProps = { const CornerLabel: React.FC<CornerLabelProps> = ({ label, className, labelClassName }) => { return (
- -
+ +
{label}
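
Usage sketch for the reworked AppIcon a couple of hunks above: the RiEditLine overlay only renders when showEditIcon is set and the icon is hovered, so callers typically pair it with an icon-picker flow. Everything below other than the AppIcon props visible in the diff (size, iconType, icon, background, showEditIcon, onClick) is an assumption:

import { useState } from 'react'
import AppIcon from '@/app/components/base/app-icon'

// Hypothetical caller: clicking the hovered edit overlay opens a picker.
const EditableAppIcon = () => {
  const [showPicker, setShowPicker] = useState(false)

  return (
    <>
      <AppIcon
        size='xl'
        iconType='emoji'
        icon='🤖'
        background='#FFEAD5'
        showEditIcon
        onClick={() => setShowPicker(true)}
      />
      {/* an AppIconPicker modal would typically be rendered here */}
      {showPicker && <div role='dialog'>icon picker placeholder</div>}
    </>
  )
}

export default EditableAppIcon
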
diff --git a/web/app/components/base/effect/index.tsx b/web/app/components/base/effect/index.tsx new file mode 100644 index 0000000000..95afb1ba5f --- /dev/null +++ b/web/app/components/base/effect/index.tsx @@ -0,0 +1,18 @@ +import React from 'react' +import cn from '@/utils/classnames' + +type EffectProps = { + className?: string +} + +const Effect = ({ + className, +}: EffectProps) => { + return ( +
+ ) +} + +export default React.memo(Effect) diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx index 02bb3ad673..5090b945e3 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/index.tsx @@ -114,7 +114,7 @@ const FileUploaderInAttachment = ({ ) } -type FileUploaderInAttachmentWrapperProps = { +export type FileUploaderInAttachmentWrapperProps = { value?: FileEntity[] onChange: (files: FileEntity[]) => void fileConfig: FileUpload diff --git a/web/app/components/base/form/components/field/custom-select.tsx b/web/app/components/base/form/components/field/custom-select.tsx new file mode 100644 index 0000000000..0e605184dc --- /dev/null +++ b/web/app/components/base/form/components/field/custom-select.tsx @@ -0,0 +1,41 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import type { CustomSelectProps, Option } from '../../../select/custom' +import CustomSelect from '../../../select/custom' +import type { LabelProps } from '../label' +import Label from '../label' + +type CustomSelectFieldProps = { + label: string + labelOptions?: Omit + options: T[] + className?: string +} & Omit, 'options' | 'value' | 'onChange'> + +const CustomSelectField = ({ + label, + labelOptions, + options, + className, + ...selectProps +}: CustomSelectFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default CustomSelectField diff --git a/web/app/components/base/form/components/field/file-types.tsx b/web/app/components/base/form/components/field/file-types.tsx new file mode 100644 index 0000000000..44c77dc894 --- /dev/null +++ b/web/app/components/base/form/components/field/file-types.tsx @@ -0,0 +1,83 @@ +import cn from '@/utils/classnames' +import type { LabelProps } from '../label' +import { useFieldContext } from '../..' +import Label from '../label' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import FileTypeItem from '@/app/components/workflow/nodes/_base/components/file-type-item' +import { useCallback } from 'react' + +type FieldValue = { + allowedFileTypes: string[], + allowedFileExtensions: string[] +} + +type FileTypesFieldProps = { + label: string + labelOptions?: Omit + className?: string +} + +const FileTypesField = ({ + label, + labelOptions, + className, +}: FileTypesFieldProps) => { + const field = useFieldContext() + + const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => { + let newAllowFileTypes = [...field.state.value.allowedFileTypes] + if (type === SupportUploadFileTypes.custom) { + if (!newAllowFileTypes.includes(SupportUploadFileTypes.custom)) + newAllowFileTypes = [SupportUploadFileTypes.custom] + else + newAllowFileTypes = newAllowFileTypes.filter(v => v !== type) + } + else { + newAllowFileTypes = newAllowFileTypes.filter(v => v !== SupportUploadFileTypes.custom) + if (newAllowFileTypes.includes(type)) + newAllowFileTypes = newAllowFileTypes.filter(v => v !== type) + else + newAllowFileTypes.push(type) + } + field.handleChange({ + ...field.state.value, + allowedFileTypes: newAllowFileTypes, + }) + }, [field]) + + const handleCustomFileTypesChange = useCallback((customFileTypes: string[]) => { + field.handleChange({ + ...field.state.value, + allowedFileExtensions: customFileTypes, + }) + }, [field]) + + return ( +
+
+ ) +} + +export default FileTypesField diff --git a/web/app/components/base/form/components/field/file-uploader.tsx b/web/app/components/base/form/components/field/file-uploader.tsx new file mode 100644 index 0000000000..2e4e26b5d6 --- /dev/null +++ b/web/app/components/base/form/components/field/file-uploader.tsx @@ -0,0 +1,40 @@ +import React from 'react' +import { useFieldContext } from '../..' +import type { LabelProps } from '../label' +import Label from '../label' +import cn from '@/utils/classnames' +import type { FileUploaderInAttachmentWrapperProps } from '../../../file-uploader/file-uploader-in-attachment' +import FileUploaderInAttachmentWrapper from '../../../file-uploader/file-uploader-in-attachment' +import type { FileEntity } from '../../../file-uploader/types' + +type FileUploaderFieldProps = { + label: string + labelOptions?: Omit + className?: string +} & Omit + +const FileUploaderField = ({ + label, + labelOptions, + className, + ...inputProps +}: FileUploaderFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default FileUploaderField diff --git a/web/app/components/base/form/components/field/input-type-select/hooks.tsx b/web/app/components/base/form/components/field/input-type-select/hooks.tsx new file mode 100644 index 0000000000..cc1192414d --- /dev/null +++ b/web/app/components/base/form/components/field/input-type-select/hooks.tsx @@ -0,0 +1,52 @@ +import { InputTypeEnum } from './types' +import { PipelineInputVarType } from '@/models/pipeline' +import { useTranslation } from 'react-i18next' +import { + RiAlignLeft, + RiCheckboxLine, + RiFileCopy2Line, + RiFileTextLine, + RiHashtag, + RiListCheck3, + RiTextSnippet, +} from '@remixicon/react' + +const i18nFileTypeMap: Record = { + 'number': 'number', + 'file': 'single-file', + 'file-list': 'multi-files', +} + +const INPUT_TYPE_ICON = { + [PipelineInputVarType.textInput]: RiTextSnippet, + [PipelineInputVarType.paragraph]: RiAlignLeft, + [PipelineInputVarType.number]: RiHashtag, + [PipelineInputVarType.select]: RiListCheck3, + [PipelineInputVarType.checkbox]: RiCheckboxLine, + [PipelineInputVarType.singleFile]: RiFileTextLine, + [PipelineInputVarType.multiFiles]: RiFileCopy2Line, +} + +const DATA_TYPE = { + [PipelineInputVarType.textInput]: 'string', + [PipelineInputVarType.paragraph]: 'string', + [PipelineInputVarType.number]: 'number', + [PipelineInputVarType.select]: 'string', + [PipelineInputVarType.checkbox]: 'boolean', + [PipelineInputVarType.singleFile]: 'file', + [PipelineInputVarType.multiFiles]: 'array[file]', +} + +export const useInputTypeOptions = (supportFile: boolean) => { + const { t } = useTranslation() + const options = supportFile ? InputTypeEnum.options : InputTypeEnum.exclude(['file', 'file-list']).options + + return options.map((value) => { + return { + value, + label: t(`appDebug.variableConfig.${i18nFileTypeMap[value] || value}`), + Icon: INPUT_TYPE_ICON[value], + type: DATA_TYPE[value], + } + }) +} diff --git a/web/app/components/base/form/components/field/input-type-select/index.tsx b/web/app/components/base/form/components/field/input-type-select/index.tsx new file mode 100644 index 0000000000..256fd872d2 --- /dev/null +++ b/web/app/components/base/form/components/field/input-type-select/index.tsx @@ -0,0 +1,64 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../../..' +import type { CustomSelectProps } from '../../../../select/custom' +import CustomSelect from '../../../../select/custom' +import type { LabelProps } from '../../label' +import Label from '../../label' +import { useCallback } from 'react' +import Trigger from './trigger' +import type { FileTypeSelectOption, InputType } from './types' +import { useInputTypeOptions } from './hooks' +import Option from './option' + +type InputTypeSelectFieldProps = { + label: string + labelOptions?: Omit + supportFile: boolean + className?: string +} & Omit, 'options' | 'value' | 'onChange' | 'CustomTrigger' | 'CustomOption'> + +const InputTypeSelectField = ({ + label, + labelOptions, + supportFile, + className, + ...customSelectProps +}: InputTypeSelectFieldProps) => { + const field = useFieldContext() + const inputTypeOptions = useInputTypeOptions(supportFile) + + const renderTrigger = useCallback((option: FileTypeSelectOption | undefined, open: boolean) => { + return + }, []) + const renderOption = useCallback((option: FileTypeSelectOption) => { + return
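
Note on the new form field components above (CustomSelectField, FileTypesField, FileUploaderField, InputTypeSelectField): they share one pattern, read the current value from useFieldContext, render a label plus a control, and write changes back through field.handleChange. A minimal sketch of that pattern for a plain text input; only state.value and handleChange are taken from the diff, the rest (including the generic on useFieldContext and the form library behind it) is assumed:

import cn from '@/utils/classnames'
import { useFieldContext } from '../..'

type TextFieldProps = {
  label: string
  className?: string
}

// Sketch only: mirrors the structure of the field components in the diff,
// not an actual file from this changeset.
const TextField = ({ label, className }: TextFieldProps) => {
  const field = useFieldContext<string>()

  return (
    <div className={cn('flex flex-col gap-1', className)}>
      <div>{label}</div>
      <input
        value={field.state.value}
        onChange={e => field.handleChange(e.target.value)}
      />
    </div>
  )
}

export default TextField
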