From 7bcaa513fa2cc007b6c619b569a2ef090174ceb9 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 3 Sep 2025 08:56:00 +0800 Subject: [PATCH 1/7] chore: remove duplicate test helper classes from api root directory (#25024) --- api/child_class.py | 11 ----------- api/lazy_load_class.py | 11 ----------- 2 files changed, 22 deletions(-) delete mode 100644 api/child_class.py delete mode 100644 api/lazy_load_class.py diff --git a/api/child_class.py b/api/child_class.py deleted file mode 100644 index b210607b92..0000000000 --- a/api/child_class.py +++ /dev/null @@ -1,11 +0,0 @@ -from tests.integration_tests.utils.parent_class import ParentClass - - -class ChildClass(ParentClass): - """Test child class for module import helper tests""" - - def __init__(self, name): - super().__init__(name) - - def get_name(self): - return f"Child: {self.name}" diff --git a/api/lazy_load_class.py b/api/lazy_load_class.py deleted file mode 100644 index dd3c2a16e8..0000000000 --- a/api/lazy_load_class.py +++ /dev/null @@ -1,11 +0,0 @@ -from tests.integration_tests.utils.parent_class import ParentClass - - -class LazyLoadChildClass(ParentClass): - """Test lazy load child class for module import helper tests""" - - def __init__(self, name): - super().__init__(name) - - def get_name(self): - return self.name From f540d0b74708400150cdfd15091482bd7a59f21b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 3 Sep 2025 08:56:23 +0800 Subject: [PATCH 2/7] chore: remove ty type checker from reformat script and pre-commit hooks (#25021) --- dev/reformat | 3 --- web/.husky/pre-commit | 9 --------- 2 files changed, 12 deletions(-) diff --git a/dev/reformat b/dev/reformat index 9e4f5d2a59..71cb6abb1e 100755 --- a/dev/reformat +++ b/dev/reformat @@ -14,8 +14,5 @@ uv run --directory api --dev ruff format ./ # run dotenv-linter linter uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example -# run ty check -dev/ty-check - # run mypy check dev/mypy-check diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index 55a8124938..2ad3922e99 100644 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -41,15 +41,6 @@ if $api_modified; then echo "Please run 'dev/reformat' to fix the fixable linting errors." exit 1 fi - - # run ty checks - uv run --directory api --dev ty check || status=$? - status=${status:-0} - if [ $status -ne 0 ]; then - echo "ty type checker on api module error, exit code: $status" - echo "Please run 'dev/ty-check' to check the type errors." 
- exit 1 - fi fi if $web_modified; then From bc9efa7ea827c931a4ac4454851f2d3c8d3765fb Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Wed, 3 Sep 2025 08:56:48 +0800 Subject: [PATCH 3/7] Refactor: use DatasourceType.XX.value instead of hardcoded (#25015) Signed-off-by: Yongtao Huang Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 1 - api/controllers/console/datasets/data_source.py | 3 ++- api/controllers/console/datasets/datasets.py | 9 ++++++--- api/controllers/console/datasets/datasets_document.py | 9 +++++---- api/core/indexing_runner.py | 9 ++++++--- api/core/rag/extractor/extract_processor.py | 4 ++-- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index e36f308bd4..9f829e27fd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -526,7 +526,6 @@ class PublishedWorkflowApi(Resource): ) app_model.workflow_id = workflow.id - db.session.commit() workflow_created_at = TimestampField().format(workflow.created_at) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 6083a53bec..e4d5f1be6e 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required from core.indexing_runner import IndexingRunner +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db @@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource): workspace_id = notion_info["workspace_id"] for page in notion_info["pages"]: extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a5a18e7f33..11b7b1fec0 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -22,6 +22,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] + datasource_type=DatasourceType.FILE.value, + upload_file=file_detail, + document_model=args["doc_form"], ) extract_settings.append(extract_setting) elif args["info_list"]["data_source_type"] == "notion_import": @@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource): workspace_id = notion_info["workspace_id"] for page in notion_info["pages"]: 
extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page["page_id"], @@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource): website_info_list = args["info_list"]["website_info_list"] for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": website_info_list["provider"], "job_id": website_info_list["job_id"], diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 22bb81f9e3..f9703f5a21 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -40,6 +40,7 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from fields.document_fields import ( @@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file, document_model=document.doc_form + datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() @@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form + datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": data_source_info["notion_workspace_id"], "notion_obj_id": data_source_info["notion_page_id"], @@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): extract_settings.append(extract_setting) elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": data_source_info["provider"], "job_id": data_source_info["job_id"], diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 4a768618f5..d31109f7a7 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore +from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor @@ -340,7 +341,9 @@ class IndexingRunner: if file_detail: 
extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form + datasource_type=DatasourceType.FILE.value, + upload_file=file_detail, + document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) elif dataset_document.data_source_type == "notion_import": @@ -351,7 +354,7 @@ class IndexingRunner: ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type="notion_import", + datasource_type=DatasourceType.NOTION.value, notion_info={ "notion_workspace_id": data_source_info["notion_workspace_id"], "notion_obj_id": data_source_info["notion_page_id"], @@ -371,7 +374,7 @@ class IndexingRunner: ): raise ValueError("no website import info found") extract_setting = ExtractSetting( - datasource_type="website_crawl", + datasource_type=DatasourceType.WEBSITE.value, website_info={ "provider": data_source_info["provider"], "job_id": data_source_info["job_id"], diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index e6b28b1bf4..b5ea08173b 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -45,7 +45,7 @@ class ExtractProcessor: cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", upload_file=upload_file, document_model="text_model" + datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model" ) if return_text: delimiter = "\n" @@ -76,7 +76,7 @@ class ExtractProcessor: # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" Path(file_path).write_bytes(response.content) - extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") + extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model") if return_text: delimiter = "\n" return delimiter.join( From c0bd35594e2b311a025a49628de729a302c6ac7c Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 3 Sep 2025 09:20:16 +0800 Subject: [PATCH 4/7] feat: add test containers based tests for tools manage service (#25028) --- .../test_workflow_tools_manage_service.py | 716 ++++++++++++++++++ 1 file changed, 716 insertions(+) create mode 100644 api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py new file mode 100644 index 0000000000..cb1e79d507 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -0,0 +1,716 @@ +import json +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.tools import WorkflowToolProvider +from models.workflow import Workflow as WorkflowModel +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.tools.workflow_tools_manage_service import WorkflowToolManageService + + +class 
TestWorkflowToolManageService: + """Integration tests for WorkflowToolManageService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch( + "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" + ) as mock_workflow_tool_provider_controller, + patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as mock_tool_label_manager, + patch("services.tools.workflow_tools_manage_service.ToolTransformService") as mock_tool_transform_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + # Mock WorkflowToolProviderController + mock_workflow_tool_provider_controller.from_db.return_value = None + + # Mock ToolLabelManager + mock_tool_label_manager.update_tool_labels.return_value = None + + # Mock ToolTransformService + mock_tool_transform_service.workflow_provider_to_controller.return_value = None + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + "workflow_tool_provider_controller": mock_workflow_tool_provider_controller, + "tool_label_manager": mock_tool_label_manager, + "tool_transform_service": mock_tool_transform_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. 
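+        (Note: the workflow created here uses an empty graph; the tests below only need a
+        valid workflow_id reference on the app, not an executable graph.)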
+ + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account, workflow) - Created app, account and workflow instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "workflow", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Create workflow for the app + workflow = WorkflowModel( + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version="1.0.0", + graph=json.dumps({}), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + + from extensions.ext_database import db + + db.session.add(workflow) + db.session.commit() + + # Update app to reference the workflow + app.workflow_id = workflow.id + db.session.commit() + + return app, account, workflow + + def _create_test_workflow_tool_parameters(self): + """Helper method to create valid workflow tool parameters.""" + return [ + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + }, + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + }, + ] + + def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful workflow tool creation with valid parameters. 
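+        (External collaborators such as ToolLabelManager and ToolTransformService are mocked
+        by the fixture, so the assertions focus on persisted database state.)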
+ + This test verifies: + - Proper workflow tool creation with all required fields + - Correct database state after creation + - Proper relationship establishment + - External service integration + - Return value correctness + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup workflow tool creation parameters + tool_name = fake.word() + tool_label = fake.word() + tool_icon = {"type": "emoji", "emoji": "🔧"} + tool_description = fake.text(max_nb_chars=200) + tool_parameters = self._create_test_workflow_tool_parameters() + tool_privacy_policy = fake.text(max_nb_chars=100) + tool_labels = ["automation", "workflow"] + + # Execute the method under test + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=tool_label, + icon=tool_icon, + description=tool_description, + parameters=tool_parameters, + privacy_policy=tool_privacy_policy, + labels=tool_labels, + ) + + # Verify the result + assert result == {"result": "success"} + + # Verify database state + from extensions.ext_database import db + + # Check if workflow tool provider was created + created_tool_provider = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + WorkflowToolProvider.app_id == app.id, + ) + .first() + ) + + assert created_tool_provider is not None + assert created_tool_provider.name == tool_name + assert created_tool_provider.label == tool_label + assert created_tool_provider.icon == json.dumps(tool_icon) + assert created_tool_provider.description == tool_description + assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.privacy_policy == tool_privacy_policy + assert created_tool_provider.version == workflow.version + assert created_tool_provider.user_id == account.id + assert created_tool_provider.tenant_id == account.current_tenant.id + assert created_tool_provider.app_id == app.id + + # Verify external service calls + mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called_once() + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() + mock_external_service_dependencies[ + "tool_transform_service" + ].workflow_provider_to_controller.assert_called_once() + + def test_create_workflow_tool_duplicate_name_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when name already exists. 
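+        (Uniqueness is enforced per tenant across both the tool name and the bound app_id,
+        which is why the expected error message mentions both.)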
+ + This test verifies: + - Proper error handling for duplicate tool names + - Database constraint enforcement + - Correct error message + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first workflow tool + first_tool_name = fake.word() + first_tool_parameters = self._create_test_workflow_tool_parameters() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=first_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=first_tool_parameters, + ) + + # Attempt to create second workflow tool with same name + second_tool_parameters = self._create_test_workflow_tool_parameters() + + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=first_tool_name, # Same name + label=fake.word(), + icon={"type": "emoji", "emoji": "⚙️"}, + description=fake.text(max_nb_chars=200), + parameters=second_tool_parameters, + ) + + # Verify error message + assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) + + # Verify only one tool was created + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 1 + + def test_create_workflow_tool_invalid_app_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when app does not exist. + + This test verifies: + - Proper error handling for non-existent apps + - Correct error message + - No database changes when app is invalid + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Generate non-existent app ID + non_existent_app_id = fake.uuid4() + + # Attempt to create workflow tool with non-existent app + tool_parameters = self._create_test_workflow_tool_parameters() + + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=non_existent_app_id, # Non-existent app ID + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=tool_parameters, + ) + + # Verify error message + assert f"App {non_existent_app_id} not found" in str(exc_info.value) + + # Verify no workflow tool was created + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + + def test_create_workflow_tool_invalid_parameters_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when parameters are invalid. 
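+        (The parameter entry used below deliberately omits the "description" and "form"
+        fields, so the parameter model's validation is expected to reject it.)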
+ + This test verifies: + - Proper error handling for invalid parameter configurations + - Parameter validation enforcement + - Correct error message + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup invalid workflow tool parameters (missing required fields) + invalid_parameters = [ + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ] + + # Attempt to create workflow tool with invalid parameters + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=invalid_parameters, + ) + + # Verify error message contains validation error + assert "validation error" in str(exc_info.value).lower() + + # Verify no workflow tool was created + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + + def test_create_workflow_tool_duplicate_app_id_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when app_id already exists. + + This test verifies: + - Proper error handling for duplicate app_id + - Database constraint enforcement for app_id uniqueness + - Correct error message + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first workflow tool + first_tool_name = fake.word() + first_tool_parameters = self._create_test_workflow_tool_parameters() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=first_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=first_tool_parameters, + ) + + # Attempt to create second workflow tool with same app_id but different name + second_tool_name = fake.word() + second_tool_parameters = self._create_test_workflow_tool_parameters() + + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, # Same app_id + name=second_tool_name, # Different name + label=fake.word(), + icon={"type": "emoji", "emoji": "⚙️"}, + description=fake.text(max_nb_chars=200), + parameters=second_tool_parameters, + ) + + # Verify error message + assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) + + # Verify only one tool was created + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 1 + + def test_create_workflow_tool_workflow_not_found_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool creation fails when app has no workflow. 
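+        (The app's workflow_id is cleared first, so the service takes the
+        missing-workflow error branch.)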
+ + This test verifies: + - Proper error handling for apps without workflows + - Correct error message + - No database changes when workflow is missing + """ + fake = Faker() + + # Create test data but without workflow + app, account, _ = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Remove workflow reference from app + from extensions.ext_database import db + + app.workflow_id = None + db.session.commit() + + # Attempt to create workflow tool for app without workflow + tool_parameters = self._create_test_workflow_tool_parameters() + + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=tool_parameters, + ) + + # Verify error message + assert f"Workflow not found for app {app.id}" in str(exc_info.value) + + # Verify no workflow tool was created + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + + def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful workflow tool update with valid parameters. + + This test verifies: + - Proper workflow tool update with all required fields + - Correct database state after update + - Proper relationship maintenance + - External service integration + - Return value correctness + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial workflow tool + initial_tool_name = fake.word() + initial_tool_parameters = self._create_test_workflow_tool_parameters() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=initial_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=initial_tool_parameters, + ) + + # Get the created tool + from extensions.ext_database import db + + created_tool = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + WorkflowToolProvider.app_id == app.id, + ) + .first() + ) + + # Setup update parameters + updated_tool_name = fake.word() + updated_tool_label = fake.word() + updated_tool_icon = {"type": "emoji", "emoji": "⚙️"} + updated_tool_description = fake.text(max_nb_chars=200) + updated_tool_parameters = self._create_test_workflow_tool_parameters() + updated_tool_privacy_policy = fake.text(max_nb_chars=100) + updated_tool_labels = ["automation", "updated"] + + # Execute the update method + result = WorkflowToolManageService.update_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_tool_id=created_tool.id, + name=updated_tool_name, + label=updated_tool_label, + icon=updated_tool_icon, + description=updated_tool_description, + parameters=updated_tool_parameters, + privacy_policy=updated_tool_privacy_policy, + labels=updated_tool_labels, + ) + + # Verify the result + assert result == {"result": "success"} + + # Verify database state was updated + db.session.refresh(created_tool) + assert created_tool.name == 
updated_tool_name + assert created_tool.label == updated_tool_label + assert created_tool.icon == json.dumps(updated_tool_icon) + assert created_tool.description == updated_tool_description + assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.privacy_policy == updated_tool_privacy_policy + assert created_tool.version == workflow.version + assert created_tool.updated_at is not None + + # Verify external service calls + mock_external_service_dependencies["workflow_tool_provider_controller"].from_db.assert_called() + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called() + mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() + + def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test workflow tool update fails when tool does not exist. + + This test verifies: + - Proper error handling for non-existent tools + - Correct error message + - No database changes when tool is invalid + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Generate non-existent tool ID + non_existent_tool_id = fake.uuid4() + + # Attempt to update non-existent workflow tool + tool_parameters = self._create_test_workflow_tool_parameters() + + with pytest.raises(ValueError) as exc_info: + WorkflowToolManageService.update_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_tool_id=non_existent_tool_id, # Non-existent tool ID + name=fake.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=tool_parameters, + ) + + # Verify error message + assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) + + # Verify no workflow tool was created + from extensions.ext_database import db + + tool_count = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + ) + .count() + ) + + assert tool_count == 0 + + def test_update_workflow_tool_same_name_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test workflow tool update succeeds when keeping the same name. 
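+        (Reusing the current name must not trip the duplicate-name check, since the name
+        belongs to the tool being updated.)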
+ + This test verifies: + - Proper handling when updating tool with same name + - Database state maintenance + - Update timestamp is set + """ + fake = Faker() + + # Create test data + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first workflow tool + first_tool_name = fake.word() + first_tool_parameters = self._create_test_workflow_tool_parameters() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=first_tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=first_tool_parameters, + ) + + # Get the created tool + from extensions.ext_database import db + + created_tool = ( + db.session.query(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == account.current_tenant.id, + WorkflowToolProvider.app_id == app.id, + ) + .first() + ) + + # Attempt to update tool with same name (should not fail) + result = WorkflowToolManageService.update_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_tool_id=created_tool.id, + name=first_tool_name, # Same name + label=fake.word(), + icon={"type": "emoji", "emoji": "⚙️"}, + description=fake.text(max_nb_chars=200), + parameters=first_tool_parameters, + ) + + # Verify update was successful + assert result == {"result": "success"} + + # Verify tool still exists with the same name + db.session.refresh(created_tool) + assert created_tool.name == first_tool_name + assert created_tool.updated_at is not None From 5092e5f6310243dbb695356912717cddc7283d56 Mon Sep 17 00:00:00 2001 From: Will Date: Wed, 3 Sep 2025 10:07:31 +0800 Subject: [PATCH 5/7] fix: workflow not published (#25030) --- api/controllers/console/app/workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9f829e27fd..bf20a5ae62 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -526,6 +526,7 @@ class PublishedWorkflowApi(Resource): ) app_model.workflow_id = workflow.id + db.session.commit() # NOTE: this is necessary for update app_model.workflow_id workflow_created_at = TimestampField().format(workflow.created_at) From 60c5bdd62f50ca3df97039b65b35ccdba8631caa Mon Sep 17 00:00:00 2001 From: 17hz <0x149527@gmail.com> Date: Wed, 3 Sep 2025 10:39:07 +0800 Subject: [PATCH 6/7] fix: remove redundant z-index from Field component (#25034) --- web/app/components/workflow/nodes/_base/components/field.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index d82ea027fb..44fa4f6f0a 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -38,7 +38,7 @@ const Field: FC = ({
supportFold && toggleFold()}
-        className={cn('sticky top-0 z-10 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
+        className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
{title} {required && *} From c3820f55f43aa4d4cefd56cf1ce9bb2fd55f6597 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 3 Sep 2025 10:57:58 +0800 Subject: [PATCH 7/7] chore: translate Chinese comments to English in ClickZetta Volume storage module (#25037) --- .../clickzetta_volume_storage.py | 2 +- .../clickzetta_volume/file_lifecycle.py | 190 +++++++-------- .../clickzetta_volume/volume_permissions.py | 218 +++++++++--------- 3 files changed, 205 insertions(+), 205 deletions(-) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 754c437fd7..5cc5314e25 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -87,7 +87,7 @@ class ClickZettaVolumeConfig(BaseModel): values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME")) values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_")) values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km")) - # 暂时禁用权限检查功能,直接设置为false + # Temporarily disable permission check feature, set directly to false values.setdefault("permission_check", False) # Validate required fields diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 2e0724f678..1ab5abdc99 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -1,7 +1,7 @@ -"""ClickZetta Volume文件生命周期管理 +"""ClickZetta Volume file lifecycle management -该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。 -支持知识库文件的完整生命周期管理。 +This module provides file lifecycle management features including version control, automatic cleanup, backup and restore. +Supports complete lifecycle management for knowledge base files. 
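
A minimal usage sketch (illustrative only; the storage handle and dataset id below are placeholders):

    manager = FileLifecycleManager(storage, dataset_id="dataset_123")
    meta = manager.save_with_lifecycle("doc.txt", b"content", tags={"source": "upload"})
    versions = manager.list_file_versions("doc.txt")
    manager.restore_version("doc.txt", version=1)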
""" import json @@ -15,17 +15,17 @@ logger = logging.getLogger(__name__) class FileStatus(Enum): - """文件状态枚举""" + """File status enumeration""" - ACTIVE = "active" # 活跃状态 - ARCHIVED = "archived" # 已归档 - DELETED = "deleted" # 已删除(软删除) - BACKUP = "backup" # 备份文件 + ACTIVE = "active" # Active status + ARCHIVED = "archived" # Archived + DELETED = "deleted" # Deleted (soft delete) + BACKUP = "backup" # Backup file @dataclass class FileMetadata: - """文件元数据""" + """File metadata""" filename: str size: int | None @@ -38,7 +38,7 @@ class FileMetadata: parent_version: Optional[int] = None def to_dict(self) -> dict: - """转换为字典格式""" + """Convert to dictionary format""" data = asdict(self) data["created_at"] = self.created_at.isoformat() data["modified_at"] = self.modified_at.isoformat() @@ -47,7 +47,7 @@ class FileMetadata: @classmethod def from_dict(cls, data: dict) -> "FileMetadata": - """从字典创建实例""" + """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) data["modified_at"] = datetime.fromisoformat(data["modified_at"]) @@ -56,14 +56,14 @@ class FileMetadata: class FileLifecycleManager: - """文件生命周期管理器""" + """File lifecycle manager""" def __init__(self, storage, dataset_id: Optional[str] = None): - """初始化生命周期管理器 + """Initialize lifecycle manager Args: - storage: ClickZetta Volume存储实例 - dataset_id: 数据集ID(用于Table Volume) + storage: ClickZetta Volume storage instance + dataset_id: Dataset ID (for Table Volume) """ self._storage = storage self._dataset_id = dataset_id @@ -72,21 +72,21 @@ class FileLifecycleManager: self._backup_prefix = ".backups/" self._deleted_prefix = ".deleted/" - # 获取权限管理器(如果存在) + # Get permission manager (if exists) self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: - """保存文件并管理生命周期 + """Save file and manage lifecycle Args: - filename: 文件名 - data: 文件内容 - tags: 文件标签 + filename: File name + data: File content + tags: File tags Returns: - 文件元数据 + File metadata """ - # 权限检查 + # Permission check if not self._check_permission(filename, "save"): from .volume_permissions import VolumePermissionError @@ -98,28 +98,28 @@ class FileLifecycleManager: ) try: - # 1. 检查是否存在旧版本 + # 1. Check if old version exists metadata_dict = self._load_metadata() current_metadata = metadata_dict.get(filename) - # 2. 如果存在旧版本,创建版本备份 + # 2. If old version exists, create version backup if current_metadata: self._create_version_backup(filename, current_metadata) - # 3. 计算文件信息 + # 3. Calculate file information now = datetime.now() checksum = self._calculate_checksum(data) new_version = (current_metadata["version"] + 1) if current_metadata else 1 - # 4. 保存新文件 + # 4. Save new file self._storage.save(filename, data) - # 5. 创建元数据 + # 5. Create metadata created_at = now parent_version = None if current_metadata: - # 如果created_at是字符串,转换为datetime + # If created_at is string, convert to datetime if isinstance(current_metadata["created_at"], str): created_at = datetime.fromisoformat(current_metadata["created_at"]) else: @@ -138,7 +138,7 @@ class FileLifecycleManager: parent_version=parent_version, ) - # 6. 更新元数据 + # 6. 
Update metadata metadata_dict[filename] = file_metadata.to_dict() self._save_metadata(metadata_dict) @@ -150,13 +150,13 @@ class FileLifecycleManager: raise def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: - """获取文件元数据 + """Get file metadata Args: - filename: 文件名 + filename: File name Returns: - 文件元数据,如果不存在返回None + File metadata, returns None if not exists """ try: metadata_dict = self._load_metadata() @@ -168,37 +168,37 @@ class FileLifecycleManager: return None def list_file_versions(self, filename: str) -> list[FileMetadata]: - """列出文件的所有版本 + """List all versions of a file Args: - filename: 文件名 + filename: File name Returns: - 文件版本列表,按版本号排序 + File version list, sorted by version number """ try: versions = [] - # 获取当前版本 + # Get current version current_metadata = self.get_file_metadata(filename) if current_metadata: versions.append(current_metadata) - # 获取历史版本 + # Get historical versions try: version_files = self._storage.scan(self._dataset_id or "", files=True) for file_path in version_files: if file_path.startswith(f"{self._version_prefix}{filename}.v"): - # 解析版本号 + # Parse version number version_str = file_path.split(".v")[-1].split(".")[0] try: version_num = int(version_str) - # 这里简化处理,实际应该从版本文件中读取元数据 - # 暂时创建基本的元数据信息 + # Simplified processing here, should actually read metadata from version file + # Temporarily create basic metadata information except ValueError: continue except: - # 如果无法扫描版本文件,只返回当前版本 + # If cannot scan version files, only return current version pass return sorted(versions, key=lambda x: x.version or 0, reverse=True) @@ -208,32 +208,32 @@ class FileLifecycleManager: return [] def restore_version(self, filename: str, version: int) -> bool: - """恢复文件到指定版本 + """Restore file to specified version Args: - filename: 文件名 - version: 要恢复的版本号 + filename: File name + version: Version number to restore Returns: - 恢复是否成功 + Whether restore succeeded """ try: version_filename = f"{self._version_prefix}{filename}.v{version}" - # 检查版本文件是否存在 + # Check if version file exists if not self._storage.exists(version_filename): logger.warning("Version %s of %s not found", version, filename) return False - # 读取版本文件内容 + # Read version file content version_data = self._storage.load_once(version_filename) - # 保存当前版本为备份 + # Save current version as backup current_metadata = self.get_file_metadata(filename) if current_metadata: self._create_version_backup(filename, current_metadata.to_dict()) - # 恢复文件 + # Restore file self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)}) return True @@ -242,21 +242,21 @@ class FileLifecycleManager: return False def archive_file(self, filename: str) -> bool: - """归档文件 + """Archive file Args: - filename: 文件名 + filename: File name Returns: - 归档是否成功 + Whether archive succeeded """ - # 权限检查 + # Permission check if not self._check_permission(filename, "archive"): logger.warning("Permission denied for archive operation on file: %s", filename) return False try: - # 更新文件状态为归档 + # Update file status to archived metadata_dict = self._load_metadata() if filename not in metadata_dict: logger.warning("File %s not found in metadata", filename) @@ -275,36 +275,36 @@ class FileLifecycleManager: return False def soft_delete_file(self, filename: str) -> bool: - """软删除文件(移动到删除目录) + """Soft delete file (move to deleted directory) Args: - filename: 文件名 + filename: File name Returns: - 删除是否成功 + Whether delete succeeded """ - # 权限检查 + # Permission check if not self._check_permission(filename, "delete"): logger.warning("Permission denied for 
soft delete operation on file: %s", filename) return False try: - # 检查文件是否存在 + # Check if file exists if not self._storage.exists(filename): logger.warning("File %s not found", filename) return False - # 读取文件内容 + # Read file content file_data = self._storage.load_once(filename) - # 移动到删除目录 + # Move to deleted directory deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" self._storage.save(deleted_filename, file_data) - # 删除原文件 + # Delete original file self._storage.delete(filename) - # 更新元数据 + # Update metadata metadata_dict = self._load_metadata() if filename in metadata_dict: metadata_dict[filename]["status"] = FileStatus.DELETED.value @@ -319,27 +319,27 @@ class FileLifecycleManager: return False def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int: - """清理旧版本文件 + """Cleanup old version files Args: - max_versions: 保留的最大版本数 - max_age_days: 版本文件的最大保留天数 + max_versions: Maximum number of versions to keep + max_age_days: Maximum retention days for version files Returns: - 清理的文件数量 + Number of files cleaned """ try: cleaned_count = 0 - # 获取所有版本文件 + # Get all version files try: all_files = self._storage.scan(self._dataset_id or "", files=True) version_files = [f for f in all_files if f.startswith(self._version_prefix)] - # 按文件分组 + # Group by file file_versions: dict[str, list[tuple[int, str]]] = {} for version_file in version_files: - # 解析文件名和版本 + # Parse filename and version parts = version_file[len(self._version_prefix) :].split(".v") if len(parts) >= 2: base_filename = parts[0] @@ -352,12 +352,12 @@ class FileLifecycleManager: except ValueError: continue - # 清理每个文件的旧版本 + # Cleanup old versions for each file for base_filename, versions in file_versions.items(): - # 按版本号排序 + # Sort by version number versions.sort(key=lambda x: x[0], reverse=True) - # 保留最新的max_versions个版本,删除其余的 + # Keep the newest max_versions versions, delete the rest if len(versions) > max_versions: to_delete = versions[max_versions:] for version_num, version_file in to_delete: @@ -377,10 +377,10 @@ class FileLifecycleManager: return 0 def get_storage_statistics(self) -> dict[str, Any]: - """获取存储统计信息 + """Get storage statistics Returns: - 存储统计字典 + Storage statistics dictionary """ try: metadata_dict = self._load_metadata() @@ -402,7 +402,7 @@ class FileLifecycleManager: for filename, metadata in metadata_dict.items(): file_meta = FileMetadata.from_dict(metadata) - # 统计文件状态 + # Count file status if file_meta.status == FileStatus.ACTIVE: stats["active_files"] = (stats["active_files"] or 0) + 1 elif file_meta.status == FileStatus.ARCHIVED: @@ -410,13 +410,13 @@ class FileLifecycleManager: elif file_meta.status == FileStatus.DELETED: stats["deleted_files"] = (stats["deleted_files"] or 0) + 1 - # 统计大小 + # Count size stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0) - # 统计版本 + # Count versions stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0) - # 找出最新和最旧的文件 + # Find newest and oldest files if oldest_date is None or file_meta.created_at < oldest_date: oldest_date = file_meta.created_at stats["oldest_file"] = filename @@ -432,12 +432,12 @@ class FileLifecycleManager: return {} def _create_version_backup(self, filename: str, metadata: dict): - """创建版本备份""" + """Create version backup""" try: - # 读取当前文件内容 + # Read current file content current_data = self._storage.load_once(filename) - # 保存为版本文件 + # Save as version file version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" 
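            # e.g. ".versions/report.pdf.v3" for version 3 of "report.pdf" (illustrative)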
self._storage.save(version_filename, current_data) @@ -447,7 +447,7 @@ class FileLifecycleManager: logger.warning("Failed to create version backup for %s: %s", filename, e) def _load_metadata(self) -> dict[str, Any]: - """加载元数据文件""" + """Load metadata file""" try: if self._storage.exists(self._metadata_file): metadata_content = self._storage.load_once(self._metadata_file) @@ -460,7 +460,7 @@ class FileLifecycleManager: return {} def _save_metadata(self, metadata_dict: dict): - """保存元数据文件""" + """Save metadata file""" try: metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) self._storage.save(self._metadata_file, metadata_content.encode("utf-8")) @@ -470,45 +470,45 @@ class FileLifecycleManager: raise def _calculate_checksum(self, data: bytes) -> str: - """计算文件校验和""" + """Calculate file checksum""" import hashlib return hashlib.md5(data).hexdigest() def _check_permission(self, filename: str, operation: str) -> bool: - """检查文件操作权限 + """Check file operation permission Args: - filename: 文件名 - operation: 操作类型 + filename: File name + operation: Operation type Returns: True if permission granted, False otherwise """ - # 如果没有权限管理器,默认允许 + # If no permission manager, allow by default if not self._permission_manager: return True try: - # 根据操作类型映射到权限 + # Map operation type to permission operation_mapping = { "save": "save", "load": "load_once", "delete": "delete", - "archive": "delete", # 归档需要删除权限 - "restore": "save", # 恢复需要写权限 - "cleanup": "delete", # 清理需要删除权限 + "archive": "delete", # Archive requires delete permission + "restore": "save", # Restore requires write permission + "cleanup": "delete", # Cleanup requires delete permission "read": "load_once", "write": "save", } mapped_operation = operation_mapping.get(operation, operation) - # 检查权限 + # Check permission result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id) return bool(result) except Exception as e: logger.exception("Permission check failed for %s operation %s", filename, operation) - # 安全默认:权限检查失败时拒绝访问 + # Safe default: deny access when permission check fails return False diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py index 4801df5102..c5fde27b9f 100644 --- a/api/extensions/storage/clickzetta_volume/volume_permissions.py +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -1,7 +1,7 @@ -"""ClickZetta Volume权限管理机制 +"""ClickZetta Volume permission management mechanism -该模块提供Volume权限检查、验证和管理功能。 -根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。 +This module provides Volume permission checking, validation and management features. +According to ClickZetta's permission model, different Volume types have different permission requirements. 
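
A minimal usage sketch (illustrative only; the connection object and dataset id below are placeholders):

    manager = VolumePermissionManager(connection, volume_type="table")
    if manager.check_permission(VolumePermission.READ, dataset_id="dataset_123"):
        ...  # proceed with the read path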
""" import logging @@ -12,29 +12,29 @@ logger = logging.getLogger(__name__) class VolumePermission(Enum): - """Volume权限类型枚举""" + """Volume permission type enumeration""" - READ = "SELECT" # 对应ClickZetta的SELECT权限 - WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 - LIST = "SELECT" # 列出文件需要SELECT权限 - DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限 - USAGE = "USAGE" # External Volume需要的基本权限 + READ = "SELECT" # Corresponds to ClickZetta's SELECT permission + WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions + LIST = "SELECT" # Listing files requires SELECT permission + DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions + USAGE = "USAGE" # Basic permission required for External Volume class VolumePermissionManager: - """Volume权限管理器""" + """Volume permission manager""" def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): - """初始化权限管理器 + """Initialize permission manager Args: - connection_or_config: ClickZetta连接对象或配置字典 - volume_type: Volume类型 (user|table|external) - volume_name: Volume名称 (用于external volume) + connection_or_config: ClickZetta connection object or configuration dictionary + volume_type: Volume type (user|table|external) + volume_name: Volume name (for external volume) """ - # 支持两种初始化方式:连接对象或配置字典 + # Support two initialization methods: connection object or configuration dictionary if isinstance(connection_or_config, dict): - # 从配置字典创建连接 + # Create connection from configuration dictionary import clickzetta # type: ignore[import-untyped] config = connection_or_config @@ -50,7 +50,7 @@ class VolumePermissionManager: self._volume_type = config.get("volume_type", volume_type) self._volume_name = config.get("volume_name", volume_name) else: - # 直接使用连接对象 + # Use connection object directly self._connection = connection_or_config self._volume_type = volume_type self._volume_name = volume_name @@ -61,14 +61,14 @@ class VolumePermissionManager: raise ValueError("volume_type is required") self._permission_cache: dict[str, set[str]] = {} - self._current_username = None # 将从连接中获取当前用户名 + self._current_username = None # Will get current username from connection def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: - """检查用户是否有执行特定操作的权限 + """Check if user has permission to perform specific operation Args: - operation: 要执行的操作类型 - dataset_id: 数据集ID (用于table volume) + operation: Type of operation to perform + dataset_id: Dataset ID (for table volume) Returns: True if user has permission, False otherwise @@ -89,20 +89,20 @@ class VolumePermissionManager: return False def _check_user_volume_permission(self, operation: VolumePermission) -> bool: - """检查User Volume权限 + """Check User Volume permission - User Volume权限规则: - - 用户对自己的User Volume有全部权限 - - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 - - 更注重连接身份验证,而不是复杂的权限检查 + User Volume permission rules: + - User has full permissions on their own User Volume + - As long as user can connect to ClickZetta, they have basic User Volume permissions by default + - Focus more on connection authentication rather than complex permission checking """ try: - # 获取当前用户名 + # Get current username current_user = self._get_current_username() - # 检查基本连接状态 + # Check basic connection status with self._connection.cursor() as cursor: - # 简单的连接测试,如果能执行查询说明用户有基本权限 + # Simple connection test, if query can be executed user has basic permissions cursor.execute("SELECT 1") result = cursor.fetchone() @@ -121,17 +121,17 @@ class 
VolumePermissionManager: except Exception as e: logger.exception("User Volume permission check failed") - # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 + # For User Volume, if permission check fails, it might be a configuration issue, provide friendlier error message logger.info("User Volume permission check failed, but permission checking is disabled in this version") return False def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: - """检查Table Volume权限 + """Check Table Volume permission - Table Volume权限规则: - - Table Volume权限继承对应表的权限 - - SELECT权限 -> 可以READ/LIST文件 - - INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件 + Table Volume permission rules: + - Table Volume permissions inherit from corresponding table permissions + - SELECT permission -> can READ/LIST files + - INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files """ if not dataset_id: logger.warning("dataset_id is required for table volume permission check") @@ -140,11 +140,11 @@ class VolumePermissionManager: table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id try: - # 检查表权限 + # Check table permissions permissions = self._get_table_permissions(table_name) required_permissions = set(operation.value.split(",")) - # 检查是否有所需的所有权限 + # Check if has all required permissions has_permission = required_permissions.issubset(permissions) logger.debug( @@ -163,22 +163,22 @@ class VolumePermissionManager: return False def _check_external_volume_permission(self, operation: VolumePermission) -> bool: - """检查External Volume权限 + """Check External Volume permission - External Volume权限规则: - - 尝试获取对External Volume的权限 - - 如果权限检查失败,进行备选验证 - - 对于开发环境,提供更宽松的权限检查 + External Volume permission rules: + - Try to get permissions for External Volume + - If permission check fails, perform fallback verification + - For development environment, provide more lenient permission checking """ if not self._volume_name: logger.warning("volume_name is required for external volume permission check") return False try: - # 检查External Volume权限 + # Check External Volume permissions permissions = self._get_external_volume_permissions(self._volume_name) - # External Volume权限映射:根据操作类型确定所需权限 + # External Volume permission mapping: determine required permissions based on operation type required_permissions = set() if operation in [VolumePermission.READ, VolumePermission.LIST]: @@ -186,7 +186,7 @@ class VolumePermissionManager: elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: required_permissions.add("write") - # 检查是否有所需的所有权限 + # Check if has all required permissions has_permission = required_permissions.issubset(permissions) logger.debug( @@ -198,11 +198,11 @@ class VolumePermissionManager: has_permission, ) - # 如果权限检查失败,尝试备选验证 + # If permission check fails, try fallback verification if not has_permission: logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name) - # 备选验证:尝试列出Volume来验证基本访问权限 + # Fallback verification: try listing Volume to verify basic access permissions try: with self._connection.cursor() as cursor: cursor.execute("SHOW VOLUMES") @@ -222,13 +222,13 @@ class VolumePermissionManager: return False def _get_table_permissions(self, table_name: str) -> set[str]: - """获取用户对指定表的权限 + """Get user permissions for specified table Args: - table_name: 表名 + table_name: Table name Returns: - 用户对该表的权限集合 + Set of user permissions for this table """ cache_key = f"table:{table_name}" @@ -239,18 +239,18 @@ class 
         try:
             with self._connection.cursor() as cursor:
-                # 使用正确的ClickZetta语法检查当前用户权限
+                # Check the current user's grants using ClickZetta's SHOW GRANTS syntax
                 cursor.execute("SHOW GRANTS")
                 grants = cursor.fetchall()

-                # 解析权限结果,查找对该表的权限
+                # Parse the grant rows and collect the permissions on this table
                 for grant in grants:
-                    if len(grant) >= 3:  # 典型格式: (privilege, object_type, object_name, ...)
+                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                         privilege = grant[0].upper()
                         object_type = grant[1].upper() if len(grant) > 1 else ""
                         object_name = grant[2] if len(grant) > 2 else ""

-                        # 检查是否是对该表的权限
+                        # Check whether this grant applies to the table
                         if (
                             object_type == "TABLE"
                             and object_name == table_name
@@ -263,7 +263,7 @@ class VolumePermissionManager:
                         else:
                             permissions.add(privilege)

-                # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
+                # If no explicit permissions were found, run a simple probe query to verify access
                 if not permissions:
                     try:
                         cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
@@ -273,15 +273,15 @@ class VolumePermissionManager:

         except Exception as e:
             logger.warning("Could not check table permissions for %s: %s", table_name, e)
-            # 安全默认:权限检查失败时拒绝访问
+            # Safe default: deny access when the permission check fails
             pass

-        # 缓存权限信息
+        # Cache the permission information
         self._permission_cache[cache_key] = permissions
         return permissions

     def _get_current_username(self) -> str:
-        """获取当前用户名"""
+        """Get the current username"""
         if self._current_username:
             return self._current_username
@@ -298,7 +298,7 @@ class VolumePermissionManager:
             return "unknown"

     def _get_user_permissions(self, username: str) -> set[str]:
-        """获取用户的基本权限集合"""
+        """Get the user's basic permission set"""
         cache_key = f"user_permissions:{username}"

         if cache_key in self._permission_cache:
@@ -308,17 +308,17 @@ class VolumePermissionManager:

         try:
             with self._connection.cursor() as cursor:
-                # 使用正确的ClickZetta语法检查当前用户权限
+                # Check the current user's grants using ClickZetta's SHOW GRANTS syntax
                 cursor.execute("SHOW GRANTS")
                 grants = cursor.fetchall()

-                # 解析权限结果,查找用户的基本权限
+                # Parse the grant rows and collect the user's basic permissions
                 for grant in grants:
-                    if len(grant) >= 3:  # 典型格式: (privilege, object_type, object_name, ...)
+                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                         privilege = grant[0].upper()
                         object_type = grant[1].upper() if len(grant) > 1 else ""

-                        # 收集所有相关权限
+                        # Collect all relevant permissions
                         if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
                             if privilege == "ALL":
                                 permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
@@ -327,21 +327,21 @@ class VolumePermissionManager:
         except Exception as e:
             logger.warning("Could not check user permissions for %s: %s", username, e)
-            # 安全默认:权限检查失败时拒绝访问
+            # Safe default: deny access when the permission check fails
             pass

-        # 缓存权限信息
+        # Cache the permission information
         self._permission_cache[cache_key] = permissions
         return permissions

     def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
-        """获取用户对指定External Volume的权限
+        """Get the user's permissions on the specified External Volume

         Args:
-            volume_name: External Volume名称
+            volume_name: External Volume name

         Returns:
-            用户对该Volume的权限集合
+            Set of the user's permissions on this Volume
         """
         cache_key = f"external_volume:{volume_name}"
@@ -352,15 +352,15 @@ class VolumePermissionManager:
         try:
             with self._connection.cursor() as cursor:
-                # 使用正确的ClickZetta语法检查Volume权限
+                # Check the Volume grants using ClickZetta's SHOW GRANTS ON VOLUME syntax
                 logger.info("Checking permissions for volume: %s", volume_name)
                 cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
                 grants = cursor.fetchall()

                 logger.info("Raw grants result for %s: %s", volume_name, grants)

-                # 解析权限结果
-                # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
+                # Parse the grant rows
+                # Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
                 #         grantee_name, grantor_name, grant_option, granted_time)
                 for grant in grants:
                     logger.info("Processing grant: %s", grant)
@@ -378,7 +378,7 @@ class VolumePermissionManager:
                         object_name,
                     )

-                    # 检查是否是对该Volume的权限或者是层级权限
+                    # Check whether the grant applies to this Volume directly or through the object hierarchy
                     if (
                         granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
                     ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
@@ -399,14 +399,14 @@ class VolumePermissionManager:
             logger.info("Final permissions for %s: %s", volume_name, permissions)

-            # 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
+            # If no explicit permissions were found, check the volume list to verify basic access
             if not permissions:
                 try:
                     cursor.execute("SHOW VOLUMES")
                     volumes = cursor.fetchall()
                     for volume in volumes:
                         if len(volume) > 0 and volume[0] == volume_name:
-                            permissions.add("read")  # 至少有读权限
+                            permissions.add("read")  # Assume at least read permission
                             logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
                             break
                 except Exception:
@@ -414,7 +414,7 @@ class VolumePermissionManager:
         except Exception as e:
             logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
-            # 在权限检查失败时,尝试基本的Volume访问验证
+            # When the permission check fails, fall back to a basic Volume access check
             try:
                 with self._connection.cursor() as cursor:
                     cursor.execute("SHOW VOLUMES")
@@ -423,30 +423,30 @@ class VolumePermissionManager:
                     for volume in volumes:
                         if len(volume) > 0 and volume[0] == volume_name:
                             logger.info("Basic volume access verified for %s", volume_name)
                             permissions.add("read")
-                            permissions.add("write")  # 假设有写权限
+                            permissions.add("write")  # Assume write permission as well
                             break
             except Exception as basic_e:
                 logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
-                # 最后的备选方案:假设有基本权限
+                # Last-resort fallback: assume basic permissions
                 permissions.add("read")

-        # 缓存权限信息
+        # Cache the permission information
         self._permission_cache[cache_key] = permissions
         return permissions

     def clear_permission_cache(self):
-        """清空权限缓存"""
+        """Clear the permission cache"""
         self._permission_cache.clear()
         logger.debug("Permission cache cleared")

     def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
-        """获取权限摘要
+        """Get a permission summary

         Args:
-            dataset_id: 数据集ID (用于table volume)
+            dataset_id: Dataset ID (for table volumes)

         Returns:
-            权限摘要字典
+            Permission summary dictionary
         """
         summary = {}
@@ -456,43 +456,43 @@ class VolumePermissionManager:
         return summary

     def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
-        """检查文件路径的权限继承
+        """Check permission inheritance for a file path

         Args:
-            file_path: 文件路径
-            operation: 要执行的操作
+            file_path: File path
+            operation: Operation to perform

         Returns:
             True if user has permission, False otherwise
         """
         try:
-            # 解析文件路径
+            # Parse the file path
             path_parts = file_path.strip("/").split("/")

             if not path_parts:
                 logger.warning("Invalid file path for permission inheritance check")
                 return False

-            # 对于Table Volume,第一层是dataset_id
+            # For Table Volume, the first path segment is the dataset_id
             if self._volume_type == "table":
                 if len(path_parts) < 1:
                     return False

                 dataset_id = path_parts[0]

-                # 检查对dataset的权限
+                # Check permissions on the dataset
                 has_dataset_permission = self.check_permission(operation, dataset_id)
                 if not has_dataset_permission:
                     logger.debug("Permission denied for dataset %s", dataset_id)
                     return False

-                # 检查路径遍历攻击
+                # Guard against path traversal attacks
                 if self._contains_path_traversal(file_path):
                     logger.warning("Path traversal attack detected: %s", file_path)
                     return False

-                # 检查是否访问敏感目录
+                # Block access to sensitive directories
                 if self._is_sensitive_path(file_path):
                     logger.warning("Access to sensitive path denied: %s", file_path)
                     return False
@@ -501,20 +501,20 @@ class VolumePermissionManager:
                 return True

             elif self._volume_type == "user":
-                # User Volume的权限继承
+                # User Volume permission inheritance
                 current_user = self._get_current_username()

-                # 检查是否试图访问其他用户的目录
+                # Check whether the user is trying to access another user's directory
                 if len(path_parts) > 1 and path_parts[0] != current_user:
                     logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
                     return False

-                # 检查基本权限
+                # Check basic permissions
                 return self.check_permission(operation)

             elif self._volume_type == "external":
-                # External Volume的权限继承
-                # 检查对External Volume的权限
+                # External Volume permission inheritance
+                # Check permissions on the External Volume
                 return self.check_permission(operation)

             else:
@@ -526,8 +526,8 @@ class VolumePermissionManager:
             return False

     def _contains_path_traversal(self, file_path: str) -> bool:
-        """检查路径是否包含路径遍历攻击"""
-        # 检查常见的路径遍历模式
+        """Check whether the path contains path traversal patterns"""
+        # Check common path traversal patterns
         traversal_patterns = [
             "../",
             "..\\",
@@ -547,18 +547,18 @@ class VolumePermissionManager:
             if pattern in file_path_lower:
                 return True

-        # 检查绝对路径
+        # Reject absolute paths
         if file_path.startswith("/") or file_path.startswith("\\"):
             return True

-        # 检查Windows驱动器路径
+        # Reject Windows drive paths
         if len(file_path) >= 2 and file_path[1] == ":":
             return True

         return False

     def _is_sensitive_path(self, file_path: str) -> bool:
-        """检查路径是否为敏感路径"""
+        """Check whether the path refers to a sensitive location"""
         sensitive_patterns = [
             "passwd",
             "shadow",
@@ -582,11 +582,11 @@ class VolumePermissionManager:
         return any(pattern in file_path_lower for pattern in sensitive_patterns)

     def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
-        """验证操作权限
+ """Validate operation permission Args: - operation: 操作名称 (save|load|exists|delete|scan) - dataset_id: 数据集ID + operation: Operation name (save|load|exists|delete|scan) + dataset_id: Dataset ID Returns: True if operation is allowed, False otherwise @@ -611,7 +611,7 @@ class VolumePermissionManager: class VolumePermissionError(Exception): - """Volume权限错误异常""" + """Volume permission error exception""" def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): self.operation = operation @@ -623,15 +623,15 @@ class VolumePermissionError(Exception): def check_volume_permission( permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None ) -> None: - """权限检查装饰器函数 + """Permission check decorator function Args: - permission_manager: 权限管理器 - operation: 操作名称 - dataset_id: 数据集ID + permission_manager: Permission manager + operation: Operation name + dataset_id: Dataset ID Raises: - VolumePermissionError: 如果没有权限 + VolumePermissionError: If no permission """ if not permission_manager.validate_operation(operation, dataset_id): error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"