diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 512d14b2ee..7d0a873ebd 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -84,10 +84,8 @@ jobs: elasticsearch oceanbase - - name: Check VDB Ready (TiDB, Oceanbase) - run: | - uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py - uv run --project api python api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py + - name: Check VDB Ready (TiDB) + run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py - name: Test Vector Stores run: uv run --project api bash dev/pytest/pytest_vdb.sh diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 14fd4679a1..2b48afd550 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource): if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename or not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index bc37907a30..48142dbe73 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename or not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") try: diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 27e8dd3fa6..1467dfb6b3 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -5,7 +5,11 @@ from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError -from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token +from controllers.service_api.wraps import ( + DatasetApiResource, + cloud_edition_billing_rate_limit_check, + validate_dataset_token, +) from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager @@ -70,6 +74,7 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() @@ -193,6 +198,7 @@ class DatasetApi(DatasetApiResource): return data, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -293,6 +299,7 @@ class DatasetApi(DatasetApiResource): return result_data, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ Deletes a dataset given its ID. diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ab7ab4dcf0..e4779f3bdf 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -19,7 +19,11 @@ from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, ) -from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check +from controllers.service_api.wraps import ( + DatasetApiResource, + cloud_edition_billing_rate_limit_check, + cloud_edition_billing_resource_check, +) from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields @@ -35,6 +39,7 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() @@ -99,6 +104,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() @@ -158,6 +164,7 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, 
tenant_id, dataset_id): """Create document by upload file.""" args = {} @@ -232,6 +239,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} @@ -302,6 +310,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): class DocumentDeleteApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id): """Delete document.""" document_id = str(document_id) diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 465f71bf03..52e9bca5da 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,9 +1,10 @@ from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.service_api import api -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): dataset_id_str = str(dataset_id) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 35582feea0..1968696ee5 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse from werkzeug.exceptions import NotFound from controllers.service_api import api -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( @@ -14,6 +14,7 @@ from services.metadata_service import MetadataService class DatasetMetadataCreateServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): parser = reqparse.RequestParser() parser.add_argument("type", type=str, required=True, nullable=True, location="json") @@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=True, location="json") @@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): + 
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, action): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 337752275a..403b7f0a0c 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset @@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource): class DatasetSegmentApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -162,6 +165,7 @@ class DatasetSegmentApi(DatasetApiResource): return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -236,6 +240,7 @@ class ChildChunkApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): """Create child chunk.""" # check dataset @@ -332,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource): """Resource for updating child chunks.""" @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): """Delete child chunk.""" # check dataset @@ -370,6 +376,7 @@ class DatasetChildChunkApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): """Update child chunk.""" # check dataset diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 20189053f4..a5492d70bd 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -138,15 +138,12 @@ 
class DatasetConfigManager: if not config.get("dataset_configs"): config["dataset_configs"] = {"retrieval_model": "single"} + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + if not config["dataset_configs"].get("datasets"): config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( "datasets", {} ).get("datasets") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 9e6adc4b08..a8848b9534 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -367,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param user: account or end user :param invoke_from: invoke from source :param application_generate_entity: application generate entity + :param workflow_execution_repository: repository for workflow execution :param workflow_node_execution_repository: repository for workflow node execution :param conversation: conversation :param stream: is stream diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 7f4770fc97..fd15bd9f50 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -195,6 +195,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param user: account or end user :param application_generate_entity: application generate entity :param invoke_from: invoke from source + :param workflow_execution_repository: repository for workflow execution :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 995a30d44c..4886ffe244 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -542,8 +542,6 @@ class LBModelManager: return config - return None - def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: """ Cooldown model load balancing config diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index e0dfe0c312..a98904102c 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -251,7 +251,7 @@ class OpsTraceManager: provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["config_class"], ) - decrypt_trace_config_key = str(decrypt_trace_config) + decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True) tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) if tracing_instance is None: # create new tracing_instance and update the cache if it absent diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index e9275c31cc..e0d2857e97 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -156,9 +156,23 @@ class PluginInstallTaskStartResponse(BaseModel): task_id: str = Field(description="The ID of the install task.") -class PluginUploadResponse(BaseModel): +class 
PluginVerification(BaseModel): + """ + Verification of the plugin. + """ + + class AuthorizedCategory(StrEnum): + Langgenius = "langgenius" + Partner = "partner" + Community = "community" + + authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.") + + +class PluginDecodeResponse(BaseModel): unique_identifier: str = Field(description="The unique identifier of the plugin.") manifest: PluginDeclaration + verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information") class PluginOAuthAuthorizationUrlResponse(BaseModel): diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 1cd2dc1be7..b7f7b31655 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -10,10 +10,10 @@ from core.plugin.entities.plugin import ( PluginInstallationSource, ) from core.plugin.entities.plugin_daemon import ( + PluginDecodeResponse, PluginInstallTask, PluginInstallTaskStartResponse, PluginListResponse, - PluginUploadResponse, ) from core.plugin.impl.base import BasePluginClient @@ -53,7 +53,7 @@ class PluginInstaller(BasePluginClient): tenant_id: str, pkg: bytes, verify_signature: bool = False, - ) -> PluginUploadResponse: + ) -> PluginDecodeResponse: """ Upload a plugin package and return the plugin unique identifier. """ @@ -68,7 +68,7 @@ class PluginInstaller(BasePluginClient): return self._request_with_plugin_daemon_response( "POST", f"plugin/{tenant_id}/management/install/upload/package", - PluginUploadResponse, + PluginDecodeResponse, files=body, data=data, ) @@ -176,6 +176,18 @@ class PluginInstaller(BasePluginClient): params={"plugin_unique_identifier": plugin_unique_identifier}, ) + def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse: + """ + Decode a plugin from an identifier. 
+ """ + return self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/decode/from_identifier", + PluginDecodeResponse, + data={"plugin_unique_identifier": plugin_unique_identifier}, + headers={"Content-Type": "application/json"}, + ) + def fetch_plugin_installation_by_ids( self, tenant_id: str, plugin_ids: Sequence[str] ) -> Sequence[PluginInstallation]: diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 836a1398bf..83a4ac651f 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -22,6 +22,7 @@ class FirecrawlApp: "formats": ["markdown"], "onlyMainContent": True, "timeout": 30000, + "integration": "dify", } if params: json_data.update(params) @@ -39,7 +40,7 @@ class FirecrawlApp: def crawl_url(self, url, params=None) -> str: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post headers = self._prepare_headers() - json_data = {"url": url} + json_data = {"url": url, "integration": "dify"} if params: json_data.update(params) response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers) @@ -49,7 +50,6 @@ class FirecrawlApp: return cast(str, job_id) else: self._handle_error(response, "start crawl job") - # FIXME: unreachable code for mypy return "" # unreachable def check_crawl_status(self, job_id) -> dict[str, Any]: @@ -82,7 +82,6 @@ class FirecrawlApp: ) else: self._handle_error(response, "check crawl status") - # FIXME: unreachable code for mypy return {} # unreachable def _format_crawl_status_response( @@ -126,4 +125,31 @@ class FirecrawlApp: def _handle_error(self, response, action) -> None: error_message = response.json().get("error", "Unknown error occurred") - raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return] + + def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search + headers = self._prepare_headers() + json_data = { + "query": query, + "limit": 5, + "lang": "en", + "country": "us", + "timeout": 60000, + "ignoreInvalidURLs": False, + "scrapeOptions": {}, + "integration": "dify", + } + if params: + json_data.update(params) + response = self._post_request(f"{self.base_url}/v1/search", json_data, headers) + if response.status_code == 200: + response_data = response.json() + if not response_data.get("success"): + raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}") + return cast(dict[str, Any], response_data) + elif response.status_code in {402, 409, 500, 429, 408}: + self._handle_error(response, "perform search") + return {} # Avoid additional exception after handling error + else: + raise Exception(f"Failed to perform search. 
Status code: {response.status_code}") diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index e778b2cec4..75f3153697 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor): def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: # check file type - if not file.filename or not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 6978860529..38c0b540d5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -496,6 +496,8 @@ class DatasetRetrieval: all_documents = self.calculate_keyword_score(query, all_documents, top_k) elif index_type == "high_quality": all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) + else: + all_documents = all_documents[:top_k] if top_k else all_documents self._on_query(query, dataset_ids, app_id, user_from, user_id) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index e5ead9dc56..f82562a498 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -151,12 +151,17 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_)) if not existing: # For new records, get the next sequence number - stmt = select(func.max(WorkflowRun.sequence_number)).where( - WorkflowRun.app_id == self._app_id, - WorkflowRun.tenant_id == self._tenant_id, + # in case multiple executions are created concurrently, use for update + stmt = ( + select(func.coalesce(func.max(WorkflowRun.sequence_number), 0) + 1) + .where( + WorkflowRun.app_id == self._app_id, + WorkflowRun.tenant_id == self._tenant_id, + ) + .with_for_update() ) - max_sequence = session.scalar(stmt) - db_model.sequence_number = (max_sequence or 0) + 1 + next_seq = session.scalar(stmt) + db_model.sequence_number = int(next_seq) if next_seq is not None else 1 else: # For updates, keep the existing sequence number db_model.sequence_number = existing.sequence_number diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 875cee17e6..ee2164f22f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -639,26 +639,19 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - generator = node_instance.run() - for item in generator: - if isinstance(item, GraphEngineEvent): - if isinstance(item, BaseIterationEvent): - # add parallel info to iteration event - item.parallel_id = parallel_id - item.parallel_start_node_id = parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id - elif isinstance(item, BaseLoopEvent): - # add parallel info to loop event - item.parallel_id = parallel_id - item.parallel_start_node_id = 
parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id - - yield item + event_stream = node_instance.run() + for event in event_stream: + if isinstance(event, GraphEngineEvent): + # add parallel info to iteration event + if isinstance(event, BaseIterationEvent | BaseLoopEvent): + event.parallel_id = parallel_id + event.parallel_start_node_id = parallel_start_node_id + event.parent_parallel_id = parent_parallel_id + event.parent_parallel_start_node_id = parent_parallel_start_node_id + yield event else: - if isinstance(item, RunCompletedEvent): - run_result = item.run_result + if isinstance(event, RunCompletedEvent): + run_result = event.run_result if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries @@ -694,7 +687,7 @@ class GraphEngine: # if run failed, handle error run_result = self._handle_continue_on_error( node_instance, - item.run_result, + event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, ) @@ -797,28 +790,28 @@ class GraphEngine: should_continue_retry = False break - elif isinstance(item, RunStreamChunkEvent): + elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, + chunk_content=event.chunk_content, + from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - elif isinstance(item, RunRetrieverResourceEvent): + elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, - retriever_resources=item.retriever_resources, - context=item.context, + retriever_resources=event.retriever_resources, + context=event.context, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, diff --git a/api/models/model.py b/api/models/model.py index 229e77134e..fa83baa9cf 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -10,7 +10,6 @@ from core.plugin.entities.plugin import GenericProviderID from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import sign_tool_file from core.workflow.entities.workflow_execution import WorkflowExecutionStatus -from services.plugin.plugin_service import PluginService if TYPE_CHECKING: from models.workflow import Workflow @@ -169,6 +168,7 @@ class App(Base): @property def deleted_tools(self) -> list: from core.tools.tool_manager import ToolManager + from services.plugin.plugin_service import PluginService # get agent mode tools app_model_config = self.app_model_config diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index d2875180d8..1b026acfd6 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -421,7 +421,7 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link"]: + if icon_type_value in ["emoji", "link", "image"]: icon_type = icon_type_value else: icon_type = "emoji" diff --git 
a/api/services/errors/plugin.py b/api/services/errors/plugin.py new file mode 100644 index 0000000000..be5b144b3d --- /dev/null +++ b/api/services/errors/plugin.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class PluginInstallationForbiddenError(BaseServiceError): + pass diff --git a/api/services/feature_service.py b/api/services/feature_service.py index be85a03e80..188caf3505 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -88,6 +88,26 @@ class WebAppAuthModel(BaseModel): allow_email_password_login: bool = False +class PluginInstallationScope(StrEnum): + NONE = "none" + OFFICIAL_ONLY = "official_only" + OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners" + ALL = "all" + + +class PluginInstallationPermissionModel(BaseModel): + # Plugin installation scope – possible values: + # none: prohibit all plugin installations + # official_only: allow only Dify official plugins + # official_and_specific_partners: allow official and specific partner plugins + # all: allow installation of all plugins + plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL + + # If True, restrict plugin installation to the marketplace only + # Equivalent to ForceEnablePluginVerification + restrict_to_marketplace_only: bool = False + + class FeatureModel(BaseModel): billing: BillingModel = BillingModel() education: EducationModel = EducationModel() @@ -128,6 +148,7 @@ class SystemFeatureModel(BaseModel): license: LicenseModel = LicenseModel() branding: BrandingModel = BrandingModel() webapp_auth: WebAppAuthModel = WebAppAuthModel() + plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() class FeatureService: @@ -291,3 +312,12 @@ class FeatureService: features.license.workspaces.enabled = license_info["workspaces"]["enabled"] features.license.workspaces.limit = license_info["workspaces"]["limit"] features.license.workspaces.size = license_info["workspaces"]["used"] + + if "PluginInstallationPermission" in enterprise_info: + plugin_installation_info = enterprise_info["PluginInstallationPermission"] + features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[ + "pluginInstallationScope" + ] + features.plugin_installation_permission.restrict_to_marketplace_only = plugin_installation_info[ + "restrictToMarketplaceOnly" + ] diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index a8b64f27db..d7fb4a7c1b 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -17,11 +17,18 @@ from core.plugin.entities.plugin import ( PluginInstallation, PluginInstallationSource, ) -from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginListResponse, PluginUploadResponse +from core.plugin.entities.plugin_daemon import ( + PluginDecodeResponse, + PluginInstallTask, + PluginListResponse, + PluginVerification, +) from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client +from services.errors.plugin import PluginInstallationForbiddenError +from services.feature_service import FeatureService, PluginInstallationScope logger = logging.getLogger(__name__) @@ -86,6 +93,42 @@ class PluginService: logger.exception("failed to fetch latest plugin version") return result + @staticmethod + def 
_check_marketplace_only_permission(): + """ + Check if the marketplace only permission is enabled + """ + features = FeatureService.get_system_features() + if features.plugin_installation_permission.restrict_to_marketplace_only: + raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only") + + @staticmethod + def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]): + """ + Check the plugin installation scope + """ + features = FeatureService.get_system_features() + + match features.plugin_installation_permission.plugin_installation_scope: + case PluginInstallationScope.OFFICIAL_ONLY: + if ( + plugin_verification is None + or plugin_verification.authorized_category != PluginVerification.AuthorizedCategory.Langgenius + ): + raise PluginInstallationForbiddenError("Plugin installation is restricted to official only") + case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS: + if plugin_verification is None or plugin_verification.authorized_category not in [ + PluginVerification.AuthorizedCategory.Langgenius, + PluginVerification.AuthorizedCategory.Partner, + ]: + raise PluginInstallationForbiddenError( + "Plugin installation is restricted to official and specific partners" + ) + case PluginInstallationScope.NONE: + raise PluginInstallationForbiddenError("Installing plugins is not allowed") + case PluginInstallationScope.ALL: + pass + @staticmethod def get_debugging_key(tenant_id: str) -> str: """ @@ -208,6 +251,8 @@ class PluginService: # check if plugin pkg is already downloaded manager = PluginInstaller() + features = FeatureService.get_system_features() + try: manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier) # already downloaded, skip, and record install event @@ -215,7 +260,14 @@ class PluginService: except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(new_plugin_unique_identifier) - manager.upload_pkg(tenant_id, pkg, verify_signature=False) + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(response.verification) return manager.upgrade_plugin( tenant_id, @@ -239,6 +291,7 @@ class PluginService: """ Upgrade plugin with github """ + PluginService._check_marketplace_only_permission() manager = PluginInstaller() return manager.upgrade_plugin( tenant_id, @@ -253,33 +306,43 @@ class PluginService: ) @staticmethod - def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse: + def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: """ Upload plugin package files returns: plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() manager = PluginInstaller() - return manager.upload_pkg(tenant_id, pkg, verify_signature) + features = FeatureService.get_system_features() + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + return response @staticmethod def upload_pkg_from_github( tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False - ) -> PluginUploadResponse: + ) -> PluginDecodeResponse: """ Install plugin from github release package files, returns plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() pkg = 
download_with_size_limit( f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE ) + features = FeatureService.get_system_features() manager = PluginInstaller() - return manager.upload_pkg( + response = manager.upload_pkg( tenant_id, pkg, - verify_signature, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, ) + return response @staticmethod def upload_bundle( @@ -289,11 +352,15 @@ class PluginService: Upload a plugin bundle and return the dependencies. """ manager = PluginInstaller() + PluginService._check_marketplace_only_permission() return manager.upload_bundle(tenant_id, bundle, verify_signature) @staticmethod def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): + PluginService._check_marketplace_only_permission() + manager = PluginInstaller() + return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, @@ -307,6 +374,8 @@ class PluginService: Install plugin from github release package files, returns plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() + manager = PluginInstaller() return manager.install_from_identifiers( tenant_id, @@ -322,28 +391,33 @@ class PluginService: ) @staticmethod - def fetch_marketplace_pkg( - tenant_id: str, plugin_unique_identifier: str, verify_signature: bool = False - ) -> PluginDeclaration: + def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration: """ Fetch marketplace package """ if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") + features = FeatureService.get_system_features() + manager = PluginInstaller() try: declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) except Exception: pkg = download_plugin_pkg(plugin_unique_identifier) - declaration = manager.upload_pkg(tenant_id, pkg, verify_signature).manifest + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(response.verification) + declaration = response.manifest return declaration @staticmethod - def install_from_marketplace_pkg( - tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False - ): + def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): """ Install plugin from marketplace package files, returns installation task id @@ -353,15 +427,26 @@ class PluginService: manager = PluginInstaller() + features = FeatureService.get_system_features() + # check if already downloaded for plugin_unique_identifier in plugin_unique_identifiers: try: manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) + plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(plugin_decode_response.verification) # already downloaded, skip except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(plugin_unique_identifier) - manager.upload_pkg(tenant_id, pkg, verify_signature) + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + # check if the plugin is available to install + 
PluginService._check_plugin_installation_scope(response.verification) return manager.install_from_identifiers( tenant_id, diff --git a/api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py b/api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py deleted file mode 100644 index 94a51292ff..0000000000 --- a/api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py +++ /dev/null @@ -1,49 +0,0 @@ -import time - -import pymysql - - -def check_oceanbase_ready() -> bool: - try: - connection = pymysql.connect( - host="localhost", - port=2881, - user="root", - password="difyai123456", - ) - affected_rows = connection.query("SELECT 1") - return affected_rows == 1 - except Exception as e: - print(f"Oceanbase is not ready. Exception: {e}") - return False - finally: - if connection: - connection.close() - - -def main(): - max_attempts = 50 - retry_interval_seconds = 2 - is_oceanbase_ready = False - for attempt in range(max_attempts): - try: - is_oceanbase_ready = check_oceanbase_ready() - except Exception as e: - print(f"Oceanbase is not ready. Exception: {e}") - is_oceanbase_ready = False - - if is_oceanbase_ready: - break - else: - print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...") - time.sleep(retry_interval_seconds) - - if is_oceanbase_ready: - print("Oceanbase is ready.") - else: - print(f"Oceanbase is not ready after {max_attempts} attempting checks.") - exit(1) - - -if __name__ == "__main__": - main() diff --git a/docker/.env.example b/docker/.env.example index 195446b7ba..020deb6881 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -815,7 +815,8 @@ TEXT_GENERATION_TIMEOUT_MS=60000 # Environment Variables for db Service # ------------------------------ -PGUSER=${DB_USERNAME} +# The name of the default postgres user. +POSTGRES_USER=${DB_USERNAME} # The password for the default postgres user. POSTGRES_PASSWORD=${DB_PASSWORD} # The name of the default postgres database. 
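A note on the variable rename above: in the official postgres image, POSTGRES_USER is the variable the init script uses to create the database superuser, while PGUSER is only a libpq client default (used by tools such as pg_isready), so POSTGRES_USER is what actually provisions the account the services connect with. A quick way to confirm the renamed variables still reach the container is a one-off connection check; the following is a hypothetical sketch, not part of this PR, and it assumes the db port is published locally on 5432 and that psycopg2 is available (the api service already depends on it):

    import os

    import psycopg2

    # Read the same variables the docker env files now pass to the db service.
    conn = psycopg2.connect(
        host="localhost",
        port=5432,
        user=os.getenv("POSTGRES_USER", "postgres"),        # previously PGUSER
        password=os.getenv("POSTGRES_PASSWORD", "difyai123456"),
        dbname=os.getenv("POSTGRES_DB", "dify"),
    )
    with conn, conn.cursor() as cur:
        cur.execute("SELECT 1")
        assert cur.fetchone() == (1,)
    conn.close()
    print("postgres reachable with POSTGRES_USER credentials")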
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 55e1b55599..9534d8ef7c 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -84,7 +84,7 @@ services: image: postgres:15-alpine restart: always environment: - PGUSER: ${PGUSER:-postgres} + POSTGRES_USER: ${POSTGRES_USER:-postgres} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_DB: ${POSTGRES_DB:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -451,6 +451,14 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini + ports: + - "${OCEANBASE_VECTOR_PORT:-2881}:2881" + healthcheck: + test: [ 'CMD-SHELL', 'obclient -h127.0.0.1 -P2881 -uroot@test -p$${OB_TENANT_PASSWORD} -e "SELECT 1;"' ] + interval: 10s + retries: 30 + start_period: 30s + timeout: 10s # Oracle vector database oracle: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 2b98d098b3..b0f274c298 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -356,7 +356,7 @@ x-shared-env: &shared-api-worker-env MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} - PGUSER: ${PGUSER:-${DB_USERNAME}} + POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -591,7 +591,7 @@ services: image: postgres:15-alpine restart: always environment: - PGUSER: ${PGUSER:-postgres} + POSTGRES_USER: ${POSTGRES_USER:-postgres} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_DB: ${POSTGRES_DB:-dify} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} @@ -958,6 +958,14 @@ services: OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_SERVER_IP: 127.0.0.1 MODE: mini + ports: + - "${OCEANBASE_VECTOR_PORT:-2881}:2881" + healthcheck: + test: [ 'CMD-SHELL', 'obclient -h127.0.0.1 -P2881 -uroot@test -p$${OB_TENANT_PASSWORD} -e "SELECT 1;"' ] + interval: 10s + retries: 30 + start_period: 30s + timeout: 10s # Oracle vector database oracle: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index f261d88d48..2eba62f594 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -1,7 +1,7 @@ # ------------------------------ # Environment Variables for db Service # ------------------------------ -PGUSER=postgres +POSTGRES_USER=postgres # The password for the default postgres user. POSTGRES_PASSWORD=difyai123456 # The name of the default postgres database. 
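The compose changes above also take over the job of the deleted check_oceanbase_ready.py: the new healthcheck runs SELECT 1 through obclient inside the container, and port 2881 is now published via OCEANBASE_VECTOR_PORT. If an external wait loop is still wanted (for example in CI, where the workflow step that called the script was removed), a minimal sketch along the lines of the deleted script could look like the following; host, port, and credentials are carried over from that script and are assumptions, not values defined by this diff:

    import time

    import pymysql


    def oceanbase_ready() -> bool:
        # Try a plain MySQL-protocol connection against the published port.
        try:
            conn = pymysql.connect(host="localhost", port=2881, user="root", password="difyai123456")
        except Exception as exc:
            print(f"OceanBase is not ready yet: {exc}")
            return False
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                return cursor.fetchone() == (1,)
        finally:
            conn.close()


    if __name__ == "__main__":
        for attempt in range(50):
            if oceanbase_ready():
                print("OceanBase is ready.")
                break
            print(f"Attempt {attempt + 1} failed, retrying in 2 seconds...")
            time.sleep(2)
        else:
            raise SystemExit("OceanBase did not become ready in time.")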
diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index d0cc7ff91f..2aa192fb02 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -9,6 +9,7 @@ import { useTranslation } from 'react-i18next' import { useDebounceFn } from 'ahooks' import { RiApps2Line, + RiDragDropLine, RiExchange2Line, RiFile4Line, RiMessage3Line, @@ -16,7 +17,8 @@ import { } from '@remixicon/react' import AppCard from './AppCard' import NewAppCard from './NewAppCard' -import useAppsQueryState from './hooks/useAppsQueryState' +import useAppsQueryState from './hooks/use-apps-query-state' +import { useDSLDragDrop } from './hooks/use-dsl-drag-drop' import type { AppListResponse } from '@/models/app' import { fetchAppList } from '@/service/apps' import { useAppContext } from '@/context/app-context' @@ -29,6 +31,7 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' +import CreateFromDSLModal from '@/app/components/app/create-from-dsl-modal' const getKey = ( pageIndex: number, @@ -67,6 +70,9 @@ const Apps = () => { const [tagFilterValue, setTagFilterValue] = useState(tagIDs) const [searchKeywords, setSearchKeywords] = useState(keywords) const newAppCardRef = useRef(null) + const containerRef = useRef(null) + const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) + const [droppedDSLFile, setDroppedDSLFile] = useState() const setKeywords = useCallback((keywords: string) => { setQuery(prev => ({ ...prev, keywords })) }, [setQuery]) @@ -74,6 +80,17 @@ const Apps = () => { setQuery(prev => ({ ...prev, tagIDs })) }, [setQuery]) + const handleDSLFileDropped = useCallback((file: File) => { + setDroppedDSLFile(file) + setShowCreateFromDSLModal(true) + }, []) + + const { dragging } = useDSLDragDrop({ + onDSLFileDropped: handleDSLFileDropped, + containerRef, + enabled: isCurrentWorkspaceEditor, + }) + const { data, isLoading, error, setSize, mutate } = useSWRInfinite( (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords), fetchAppList, @@ -151,47 +168,81 @@ const Apps = () => { return ( <> -
[The JSX markup in this hunk of Apps.tsx was lost in capture and cannot be reconstructed from the surviving fragments. What remains shows the existing return markup (the search Input with its handleKeywordsChange/onClear handlers, the apps grid mapping data to NewAppCard and AppCard, the empty-state branch, and the showTagManagementModal block) being re-indented under a new wrapper, with two additions: an overlay rendered while `dragging` is true, and an editor-only hint rendered via t('app.newApp.dropDSLToCreateApp').]
- {showTagManagementModal && ( - + + {showCreateFromDSLModal && ( + { + setShowCreateFromDSLModal(false) + setDroppedDSLFile(undefined) + }} + onSuccess={() => { + setShowCreateFromDSLModal(false) + setDroppedDSLFile(undefined) + mutate() + }} + droppedFile={droppedDSLFile} + /> )} ) diff --git a/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts b/web/app/(commonLayout)/apps/hooks/use-apps-query-state.ts similarity index 100% rename from web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts rename to web/app/(commonLayout)/apps/hooks/use-apps-query-state.ts diff --git a/web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts b/web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts new file mode 100644 index 0000000000..96942ec54e --- /dev/null +++ b/web/app/(commonLayout)/apps/hooks/use-dsl-drag-drop.ts @@ -0,0 +1,72 @@ +import { useEffect, useState } from 'react' + +type DSLDragDropHookProps = { + onDSLFileDropped: (file: File) => void + containerRef: React.RefObject + enabled?: boolean +} + +export const useDSLDragDrop = ({ onDSLFileDropped, containerRef, enabled = true }: DSLDragDropHookProps) => { + const [dragging, setDragging] = useState(false) + + const handleDragEnter = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.dataTransfer?.types.includes('Files')) + setDragging(true) + } + + const handleDragOver = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + + const handleDragLeave = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + if (e.relatedTarget === null || !containerRef.current?.contains(e.relatedTarget as Node)) + setDragging(false) + } + + const handleDrop = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + setDragging(false) + + if (!e.dataTransfer) + return + + const files = [...e.dataTransfer.files] + if (files.length === 0) + return + + const file = files[0] + if (file.name.toLowerCase().endsWith('.yaml') || file.name.toLowerCase().endsWith('.yml')) + onDSLFileDropped(file) + } + + useEffect(() => { + if (!enabled) + return + + const current = containerRef.current + if (current) { + current.addEventListener('dragenter', handleDragEnter) + current.addEventListener('dragover', handleDragOver) + current.addEventListener('dragleave', handleDragLeave) + current.addEventListener('drop', handleDrop) + } + return () => { + if (current) { + current.removeEventListener('dragenter', handleDragEnter) + current.removeEventListener('dragover', handleDragOver) + current.removeEventListener('dragleave', handleDragLeave) + current.removeEventListener('drop', handleDrop) + } + } + }, [containerRef, enabled]) + + return { + dragging: enabled ? dragging : false, + } +} diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 1485964198..83a7ffd553 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -314,10 +314,10 @@ const AppPublisher = ({ {!isAppAccessSet &&

{t('app.publishApp.notSetDesc')}

}
}
- + } > @@ -326,10 +326,10 @@ const AppPublisher = ({ {appDetail?.mode === 'workflow' || appDetail?.mode === 'completion' ? ( - + } > diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index 9739ac47ea..8faafe05a8 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { MouseEventHandler } from 'react' -import { useMemo, useRef, useState } from 'react' +import { useEffect, useMemo, useRef, useState } from 'react' import { useRouter } from 'next/navigation' import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' @@ -35,6 +35,7 @@ type CreateFromDSLModalProps = { onClose: () => void activeTab?: string dslUrl?: string + droppedFile?: File } export enum CreateFromDSLModalTab { @@ -42,11 +43,11 @@ export enum CreateFromDSLModalTab { FROM_URL = 'from-url', } -const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '' }: CreateFromDSLModalProps) => { +const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '', droppedFile }: CreateFromDSLModalProps) => { const { push } = useRouter() const { t } = useTranslation() const { notify } = useContext(ToastContext) - const [currentFile, setDSLFile] = useState() + const [currentFile, setDSLFile] = useState(droppedFile) const [fileContent, setFileContent] = useState() const [currentTab, setCurrentTab] = useState(activeTab) const [dslUrlValue, setDslUrlValue] = useState(dslUrl) @@ -78,6 +79,11 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const isCreatingRef = useRef(false) + useEffect(() => { + if (droppedFile) + handleFile(droppedFile) + }, [droppedFile]) + const onCreate: MouseEventHandler = async () => { if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) return diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 691b727b8e..b48eac5458 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -50,6 +50,10 @@ const OPTION_MAP = { // user_id: 'YOU CAN DEFINE USER ID HERE', // conversation_id: 'YOU CAN DEFINE CONVERSATION ID HERE, IT MUST BE A VALID UUID', }, + userVariables: { + // avatar_url: 'YOU CAN DEFINE USER AVATAR URL HERE', + // name: 'YOU CAN DEFINE USER NAME HERE', + }, }