mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 20:17:29 +08:00
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
This commit is contained in:
commit
8c41d95d03
@ -1,11 +0,0 @@
|
|||||||
from tests.integration_tests.utils.parent_class import ParentClass
|
|
||||||
|
|
||||||
|
|
||||||
class ChildClass(ParentClass):
|
|
||||||
"""Test child class for module import helper tests"""
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
super().__init__(name)
|
|
||||||
|
|
||||||
def get_name(self):
|
|
||||||
return f"Child: {self.name}"
|
|
||||||
@ -532,7 +532,7 @@ class PublishedWorkflowApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
app_model.workflow_id = workflow.id
|
app_model.workflow_id = workflow.id
|
||||||
db.session.commit()
|
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||||
|
|
||||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from werkzeug.exceptions import NotFound
|
|||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -214,7 +215,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
workspace_id = notion_info["workspace_id"]
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type=DatasourceType.NOTION.value,
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from core.indexing_runner import IndexingRunner
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -422,7 +423,9 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
if file_details:
|
if file_details:
|
||||||
for file_detail in file_details:
|
for file_detail in file_details:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
datasource_type=DatasourceType.FILE.value,
|
||||||
|
upload_file=file_detail,
|
||||||
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||||
@ -431,7 +434,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
workspace_id = notion_info["workspace_id"]
|
workspace_id = notion_info["workspace_id"]
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type=DatasourceType.NOTION.value,
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_obj_id": page["page_id"],
|
||||||
@ -445,7 +448,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
website_info_list = args["info_list"]["website_info_list"]
|
website_info_list = args["info_list"]["website_info_list"]
|
||||||
for url in website_info_list["urls"]:
|
for url in website_info_list["urls"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="website_crawl",
|
datasource_type=DatasourceType.WEBSITE.value,
|
||||||
website_info={
|
website_info={
|
||||||
"provider": website_info_list["provider"],
|
"provider": website_info_list["provider"],
|
||||||
"job_id": website_info_list["job_id"],
|
"job_id": website_info_list["job_id"],
|
||||||
|
|||||||
@ -40,6 +40,7 @@ from core.model_manager import ModelManager
|
|||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.document_fields import (
|
from fields.document_fields import (
|
||||||
@ -425,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
|
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
@ -485,13 +486,13 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
|
|
||||||
elif document.data_source_type == "notion_import":
|
elif document.data_source_type == "notion_import":
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type=DatasourceType.NOTION.value,
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
@ -503,7 +504,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
elif document.data_source_type == "website_crawl":
|
elif document.data_source_type == "website_crawl":
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="website_crawl",
|
datasource_type=DatasourceType.WEBSITE.value,
|
||||||
website_info={
|
website_info={
|
||||||
"provider": data_source_info["provider"],
|
"provider": data_source_info["provider"],
|
||||||
"job_id": data_source_info["job_id"],
|
"job_id": data_source_info["job_id"],
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||||
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||||
@ -340,7 +341,9 @@ class IndexingRunner:
|
|||||||
|
|
||||||
if file_detail:
|
if file_detail:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
|
datasource_type=DatasourceType.FILE.value,
|
||||||
|
upload_file=file_detail,
|
||||||
|
document_model=dataset_document.doc_form,
|
||||||
)
|
)
|
||||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
|
||||||
elif dataset_document.data_source_type == "notion_import":
|
elif dataset_document.data_source_type == "notion_import":
|
||||||
@ -351,7 +354,7 @@ class IndexingRunner:
|
|||||||
):
|
):
|
||||||
raise ValueError("no notion import info found")
|
raise ValueError("no notion import info found")
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="notion_import",
|
datasource_type=DatasourceType.NOTION.value,
|
||||||
notion_info={
|
notion_info={
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
@ -371,7 +374,7 @@ class IndexingRunner:
|
|||||||
):
|
):
|
||||||
raise ValueError("no website import info found")
|
raise ValueError("no website import info found")
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="website_crawl",
|
datasource_type=DatasourceType.WEBSITE.value,
|
||||||
website_info={
|
website_info={
|
||||||
"provider": data_source_info["provider"],
|
"provider": data_source_info["provider"],
|
||||||
"job_id": data_source_info["job_id"],
|
"job_id": data_source_info["job_id"],
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class ExtractProcessor:
|
|||||||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||||
) -> Union[list[Document], str]:
|
) -> Union[list[Document], str]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
|
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
|
||||||
)
|
)
|
||||||
if return_text:
|
if return_text:
|
||||||
delimiter = "\n"
|
delimiter = "\n"
|
||||||
@ -76,7 +76,7 @@ class ExtractProcessor:
|
|||||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||||
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
||||||
Path(file_path).write_bytes(response.content)
|
Path(file_path).write_bytes(response.content)
|
||||||
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
|
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
|
||||||
if return_text:
|
if return_text:
|
||||||
delimiter = "\n"
|
delimiter = "\n"
|
||||||
return delimiter.join(
|
return delimiter.join(
|
||||||
|
|||||||
@ -87,7 +87,7 @@ class ClickZettaVolumeConfig(BaseModel):
|
|||||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||||
# 暂时禁用权限检查功能,直接设置为false
|
# Temporarily disable permission check feature, set directly to false
|
||||||
values.setdefault("permission_check", False)
|
values.setdefault("permission_check", False)
|
||||||
|
|
||||||
# Validate required fields
|
# Validate required fields
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""ClickZetta Volume文件生命周期管理
|
"""ClickZetta Volume file lifecycle management
|
||||||
|
|
||||||
该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。
|
This module provides file lifecycle management features including version control, automatic cleanup, backup and restore.
|
||||||
支持知识库文件的完整生命周期管理。
|
Supports complete lifecycle management for knowledge base files.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@ -15,17 +15,17 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class FileStatus(Enum):
|
class FileStatus(Enum):
|
||||||
"""文件状态枚举"""
|
"""File status enumeration"""
|
||||||
|
|
||||||
ACTIVE = "active" # 活跃状态
|
ACTIVE = "active" # Active status
|
||||||
ARCHIVED = "archived" # 已归档
|
ARCHIVED = "archived" # Archived
|
||||||
DELETED = "deleted" # 已删除(软删除)
|
DELETED = "deleted" # Deleted (soft delete)
|
||||||
BACKUP = "backup" # 备份文件
|
BACKUP = "backup" # Backup file
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileMetadata:
|
class FileMetadata:
|
||||||
"""文件元数据"""
|
"""File metadata"""
|
||||||
|
|
||||||
filename: str
|
filename: str
|
||||||
size: int | None
|
size: int | None
|
||||||
@ -38,7 +38,7 @@ class FileMetadata:
|
|||||||
parent_version: Optional[int] = None
|
parent_version: Optional[int] = None
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""Convert to dictionary format"""
|
||||||
data = asdict(self)
|
data = asdict(self)
|
||||||
data["created_at"] = self.created_at.isoformat()
|
data["created_at"] = self.created_at.isoformat()
|
||||||
data["modified_at"] = self.modified_at.isoformat()
|
data["modified_at"] = self.modified_at.isoformat()
|
||||||
@ -47,7 +47,7 @@ class FileMetadata:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> "FileMetadata":
|
def from_dict(cls, data: dict) -> "FileMetadata":
|
||||||
"""从字典创建实例"""
|
"""Create instance from dictionary"""
|
||||||
data = data.copy()
|
data = data.copy()
|
||||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||||
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
||||||
@ -56,14 +56,14 @@ class FileMetadata:
|
|||||||
|
|
||||||
|
|
||||||
class FileLifecycleManager:
|
class FileLifecycleManager:
|
||||||
"""文件生命周期管理器"""
|
"""File lifecycle manager"""
|
||||||
|
|
||||||
def __init__(self, storage, dataset_id: Optional[str] = None):
|
def __init__(self, storage, dataset_id: Optional[str] = None):
|
||||||
"""初始化生命周期管理器
|
"""Initialize lifecycle manager
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage: ClickZetta Volume存储实例
|
storage: ClickZetta Volume storage instance
|
||||||
dataset_id: 数据集ID(用于Table Volume)
|
dataset_id: Dataset ID (for Table Volume)
|
||||||
"""
|
"""
|
||||||
self._storage = storage
|
self._storage = storage
|
||||||
self._dataset_id = dataset_id
|
self._dataset_id = dataset_id
|
||||||
@ -72,21 +72,21 @@ class FileLifecycleManager:
|
|||||||
self._backup_prefix = ".backups/"
|
self._backup_prefix = ".backups/"
|
||||||
self._deleted_prefix = ".deleted/"
|
self._deleted_prefix = ".deleted/"
|
||||||
|
|
||||||
# 获取权限管理器(如果存在)
|
# Get permission manager (if exists)
|
||||||
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
|
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
|
||||||
|
|
||||||
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
|
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
|
||||||
"""保存文件并管理生命周期
|
"""Save file and manage lifecycle
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
data: 文件内容
|
data: File content
|
||||||
tags: 文件标签
|
tags: File tags
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文件元数据
|
File metadata
|
||||||
"""
|
"""
|
||||||
# 权限检查
|
# Permission check
|
||||||
if not self._check_permission(filename, "save"):
|
if not self._check_permission(filename, "save"):
|
||||||
from .volume_permissions import VolumePermissionError
|
from .volume_permissions import VolumePermissionError
|
||||||
|
|
||||||
@ -98,28 +98,28 @@ class FileLifecycleManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 检查是否存在旧版本
|
# 1. Check if old version exists
|
||||||
metadata_dict = self._load_metadata()
|
metadata_dict = self._load_metadata()
|
||||||
current_metadata = metadata_dict.get(filename)
|
current_metadata = metadata_dict.get(filename)
|
||||||
|
|
||||||
# 2. 如果存在旧版本,创建版本备份
|
# 2. If old version exists, create version backup
|
||||||
if current_metadata:
|
if current_metadata:
|
||||||
self._create_version_backup(filename, current_metadata)
|
self._create_version_backup(filename, current_metadata)
|
||||||
|
|
||||||
# 3. 计算文件信息
|
# 3. Calculate file information
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
checksum = self._calculate_checksum(data)
|
checksum = self._calculate_checksum(data)
|
||||||
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
||||||
|
|
||||||
# 4. 保存新文件
|
# 4. Save new file
|
||||||
self._storage.save(filename, data)
|
self._storage.save(filename, data)
|
||||||
|
|
||||||
# 5. 创建元数据
|
# 5. Create metadata
|
||||||
created_at = now
|
created_at = now
|
||||||
parent_version = None
|
parent_version = None
|
||||||
|
|
||||||
if current_metadata:
|
if current_metadata:
|
||||||
# 如果created_at是字符串,转换为datetime
|
# If created_at is string, convert to datetime
|
||||||
if isinstance(current_metadata["created_at"], str):
|
if isinstance(current_metadata["created_at"], str):
|
||||||
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
||||||
else:
|
else:
|
||||||
@ -138,7 +138,7 @@ class FileLifecycleManager:
|
|||||||
parent_version=parent_version,
|
parent_version=parent_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 更新元数据
|
# 6. Update metadata
|
||||||
metadata_dict[filename] = file_metadata.to_dict()
|
metadata_dict[filename] = file_metadata.to_dict()
|
||||||
self._save_metadata(metadata_dict)
|
self._save_metadata(metadata_dict)
|
||||||
|
|
||||||
@ -150,13 +150,13 @@ class FileLifecycleManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
|
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
|
||||||
"""获取文件元数据
|
"""Get file metadata
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文件元数据,如果不存在返回None
|
File metadata, returns None if not exists
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
metadata_dict = self._load_metadata()
|
metadata_dict = self._load_metadata()
|
||||||
@ -168,37 +168,37 @@ class FileLifecycleManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||||
"""列出文件的所有版本
|
"""List all versions of a file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文件版本列表,按版本号排序
|
File version list, sorted by version number
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
versions = []
|
versions = []
|
||||||
|
|
||||||
# 获取当前版本
|
# Get current version
|
||||||
current_metadata = self.get_file_metadata(filename)
|
current_metadata = self.get_file_metadata(filename)
|
||||||
if current_metadata:
|
if current_metadata:
|
||||||
versions.append(current_metadata)
|
versions.append(current_metadata)
|
||||||
|
|
||||||
# 获取历史版本
|
# Get historical versions
|
||||||
try:
|
try:
|
||||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||||
for file_path in version_files:
|
for file_path in version_files:
|
||||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||||
# 解析版本号
|
# Parse version number
|
||||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||||
try:
|
try:
|
||||||
version_num = int(version_str)
|
version_num = int(version_str)
|
||||||
# 这里简化处理,实际应该从版本文件中读取元数据
|
# Simplified processing here, should actually read metadata from version file
|
||||||
# 暂时创建基本的元数据信息
|
# Temporarily create basic metadata information
|
||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
except:
|
except:
|
||||||
# 如果无法扫描版本文件,只返回当前版本
|
# If cannot scan version files, only return current version
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||||
@ -208,32 +208,32 @@ class FileLifecycleManager:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
def restore_version(self, filename: str, version: int) -> bool:
|
def restore_version(self, filename: str, version: int) -> bool:
|
||||||
"""恢复文件到指定版本
|
"""Restore file to specified version
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
version: 要恢复的版本号
|
version: Version number to restore
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
恢复是否成功
|
Whether restore succeeded
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||||
|
|
||||||
# 检查版本文件是否存在
|
# Check if version file exists
|
||||||
if not self._storage.exists(version_filename):
|
if not self._storage.exists(version_filename):
|
||||||
logger.warning("Version %s of %s not found", version, filename)
|
logger.warning("Version %s of %s not found", version, filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 读取版本文件内容
|
# Read version file content
|
||||||
version_data = self._storage.load_once(version_filename)
|
version_data = self._storage.load_once(version_filename)
|
||||||
|
|
||||||
# 保存当前版本为备份
|
# Save current version as backup
|
||||||
current_metadata = self.get_file_metadata(filename)
|
current_metadata = self.get_file_metadata(filename)
|
||||||
if current_metadata:
|
if current_metadata:
|
||||||
self._create_version_backup(filename, current_metadata.to_dict())
|
self._create_version_backup(filename, current_metadata.to_dict())
|
||||||
|
|
||||||
# 恢复文件
|
# Restore file
|
||||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -242,21 +242,21 @@ class FileLifecycleManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def archive_file(self, filename: str) -> bool:
|
def archive_file(self, filename: str) -> bool:
|
||||||
"""归档文件
|
"""Archive file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
归档是否成功
|
Whether archive succeeded
|
||||||
"""
|
"""
|
||||||
# 权限检查
|
# Permission check
|
||||||
if not self._check_permission(filename, "archive"):
|
if not self._check_permission(filename, "archive"):
|
||||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 更新文件状态为归档
|
# Update file status to archived
|
||||||
metadata_dict = self._load_metadata()
|
metadata_dict = self._load_metadata()
|
||||||
if filename not in metadata_dict:
|
if filename not in metadata_dict:
|
||||||
logger.warning("File %s not found in metadata", filename)
|
logger.warning("File %s not found in metadata", filename)
|
||||||
@ -275,36 +275,36 @@ class FileLifecycleManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def soft_delete_file(self, filename: str) -> bool:
|
def soft_delete_file(self, filename: str) -> bool:
|
||||||
"""软删除文件(移动到删除目录)
|
"""Soft delete file (move to deleted directory)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
删除是否成功
|
Whether delete succeeded
|
||||||
"""
|
"""
|
||||||
# 权限检查
|
# Permission check
|
||||||
if not self._check_permission(filename, "delete"):
|
if not self._check_permission(filename, "delete"):
|
||||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查文件是否存在
|
# Check if file exists
|
||||||
if not self._storage.exists(filename):
|
if not self._storage.exists(filename):
|
||||||
logger.warning("File %s not found", filename)
|
logger.warning("File %s not found", filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 读取文件内容
|
# Read file content
|
||||||
file_data = self._storage.load_once(filename)
|
file_data = self._storage.load_once(filename)
|
||||||
|
|
||||||
# 移动到删除目录
|
# Move to deleted directory
|
||||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
self._storage.save(deleted_filename, file_data)
|
self._storage.save(deleted_filename, file_data)
|
||||||
|
|
||||||
# 删除原文件
|
# Delete original file
|
||||||
self._storage.delete(filename)
|
self._storage.delete(filename)
|
||||||
|
|
||||||
# 更新元数据
|
# Update metadata
|
||||||
metadata_dict = self._load_metadata()
|
metadata_dict = self._load_metadata()
|
||||||
if filename in metadata_dict:
|
if filename in metadata_dict:
|
||||||
metadata_dict[filename]["status"] = FileStatus.DELETED.value
|
metadata_dict[filename]["status"] = FileStatus.DELETED.value
|
||||||
@ -319,27 +319,27 @@ class FileLifecycleManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||||
"""清理旧版本文件
|
"""Cleanup old version files
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_versions: 保留的最大版本数
|
max_versions: Maximum number of versions to keep
|
||||||
max_age_days: 版本文件的最大保留天数
|
max_age_days: Maximum retention days for version files
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理的文件数量
|
Number of files cleaned
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
cleaned_count = 0
|
cleaned_count = 0
|
||||||
|
|
||||||
# 获取所有版本文件
|
# Get all version files
|
||||||
try:
|
try:
|
||||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||||
|
|
||||||
# 按文件分组
|
# Group by file
|
||||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||||
for version_file in version_files:
|
for version_file in version_files:
|
||||||
# 解析文件名和版本
|
# Parse filename and version
|
||||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||||
if len(parts) >= 2:
|
if len(parts) >= 2:
|
||||||
base_filename = parts[0]
|
base_filename = parts[0]
|
||||||
@ -352,12 +352,12 @@ class FileLifecycleManager:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 清理每个文件的旧版本
|
# Cleanup old versions for each file
|
||||||
for base_filename, versions in file_versions.items():
|
for base_filename, versions in file_versions.items():
|
||||||
# 按版本号排序
|
# Sort by version number
|
||||||
versions.sort(key=lambda x: x[0], reverse=True)
|
versions.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
|
||||||
# 保留最新的max_versions个版本,删除其余的
|
# Keep the newest max_versions versions, delete the rest
|
||||||
if len(versions) > max_versions:
|
if len(versions) > max_versions:
|
||||||
to_delete = versions[max_versions:]
|
to_delete = versions[max_versions:]
|
||||||
for version_num, version_file in to_delete:
|
for version_num, version_file in to_delete:
|
||||||
@ -377,10 +377,10 @@ class FileLifecycleManager:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_storage_statistics(self) -> dict[str, Any]:
|
def get_storage_statistics(self) -> dict[str, Any]:
|
||||||
"""获取存储统计信息
|
"""Get storage statistics
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
存储统计字典
|
Storage statistics dictionary
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
metadata_dict = self._load_metadata()
|
metadata_dict = self._load_metadata()
|
||||||
@ -402,7 +402,7 @@ class FileLifecycleManager:
|
|||||||
for filename, metadata in metadata_dict.items():
|
for filename, metadata in metadata_dict.items():
|
||||||
file_meta = FileMetadata.from_dict(metadata)
|
file_meta = FileMetadata.from_dict(metadata)
|
||||||
|
|
||||||
# 统计文件状态
|
# Count file status
|
||||||
if file_meta.status == FileStatus.ACTIVE:
|
if file_meta.status == FileStatus.ACTIVE:
|
||||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||||
elif file_meta.status == FileStatus.ARCHIVED:
|
elif file_meta.status == FileStatus.ARCHIVED:
|
||||||
@ -410,13 +410,13 @@ class FileLifecycleManager:
|
|||||||
elif file_meta.status == FileStatus.DELETED:
|
elif file_meta.status == FileStatus.DELETED:
|
||||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||||
|
|
||||||
# 统计大小
|
# Count size
|
||||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||||
|
|
||||||
# 统计版本
|
# Count versions
|
||||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||||
|
|
||||||
# 找出最新和最旧的文件
|
# Find newest and oldest files
|
||||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||||
oldest_date = file_meta.created_at
|
oldest_date = file_meta.created_at
|
||||||
stats["oldest_file"] = filename
|
stats["oldest_file"] = filename
|
||||||
@ -432,12 +432,12 @@ class FileLifecycleManager:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _create_version_backup(self, filename: str, metadata: dict):
|
def _create_version_backup(self, filename: str, metadata: dict):
|
||||||
"""创建版本备份"""
|
"""Create version backup"""
|
||||||
try:
|
try:
|
||||||
# 读取当前文件内容
|
# Read current file content
|
||||||
current_data = self._storage.load_once(filename)
|
current_data = self._storage.load_once(filename)
|
||||||
|
|
||||||
# 保存为版本文件
|
# Save as version file
|
||||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||||
self._storage.save(version_filename, current_data)
|
self._storage.save(version_filename, current_data)
|
||||||
|
|
||||||
@ -447,7 +447,7 @@ class FileLifecycleManager:
|
|||||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||||
|
|
||||||
def _load_metadata(self) -> dict[str, Any]:
|
def _load_metadata(self) -> dict[str, Any]:
|
||||||
"""加载元数据文件"""
|
"""Load metadata file"""
|
||||||
try:
|
try:
|
||||||
if self._storage.exists(self._metadata_file):
|
if self._storage.exists(self._metadata_file):
|
||||||
metadata_content = self._storage.load_once(self._metadata_file)
|
metadata_content = self._storage.load_once(self._metadata_file)
|
||||||
@ -460,7 +460,7 @@ class FileLifecycleManager:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _save_metadata(self, metadata_dict: dict):
|
def _save_metadata(self, metadata_dict: dict):
|
||||||
"""保存元数据文件"""
|
"""Save metadata file"""
|
||||||
try:
|
try:
|
||||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||||
@ -470,45 +470,45 @@ class FileLifecycleManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _calculate_checksum(self, data: bytes) -> str:
|
def _calculate_checksum(self, data: bytes) -> str:
|
||||||
"""计算文件校验和"""
|
"""Calculate file checksum"""
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
return hashlib.md5(data).hexdigest()
|
return hashlib.md5(data).hexdigest()
|
||||||
|
|
||||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||||
"""检查文件操作权限
|
"""Check file operation permission
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filename: 文件名
|
filename: File name
|
||||||
operation: 操作类型
|
operation: Operation type
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if permission granted, False otherwise
|
True if permission granted, False otherwise
|
||||||
"""
|
"""
|
||||||
# 如果没有权限管理器,默认允许
|
# If no permission manager, allow by default
|
||||||
if not self._permission_manager:
|
if not self._permission_manager:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 根据操作类型映射到权限
|
# Map operation type to permission
|
||||||
operation_mapping = {
|
operation_mapping = {
|
||||||
"save": "save",
|
"save": "save",
|
||||||
"load": "load_once",
|
"load": "load_once",
|
||||||
"delete": "delete",
|
"delete": "delete",
|
||||||
"archive": "delete", # 归档需要删除权限
|
"archive": "delete", # Archive requires delete permission
|
||||||
"restore": "save", # 恢复需要写权限
|
"restore": "save", # Restore requires write permission
|
||||||
"cleanup": "delete", # 清理需要删除权限
|
"cleanup": "delete", # Cleanup requires delete permission
|
||||||
"read": "load_once",
|
"read": "load_once",
|
||||||
"write": "save",
|
"write": "save",
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_operation = operation_mapping.get(operation, operation)
|
mapped_operation = operation_mapping.get(operation, operation)
|
||||||
|
|
||||||
# 检查权限
|
# Check permission
|
||||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||||
return bool(result)
|
return bool(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||||
# 安全默认:权限检查失败时拒绝访问
|
# Safe default: deny access when permission check fails
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""ClickZetta Volume权限管理机制
|
"""ClickZetta Volume permission management mechanism
|
||||||
|
|
||||||
该模块提供Volume权限检查、验证和管理功能。
|
This module provides Volume permission checking, validation and management features.
|
||||||
根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。
|
According to ClickZetta's permission model, different Volume types have different permission requirements.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -12,29 +12,29 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VolumePermission(Enum):
|
class VolumePermission(Enum):
|
||||||
"""Volume权限类型枚举"""
|
"""Volume permission type enumeration"""
|
||||||
|
|
||||||
READ = "SELECT" # 对应ClickZetta的SELECT权限
|
READ = "SELECT" # Corresponds to ClickZetta's SELECT permission
|
||||||
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
|
WRITE = "INSERT,UPDATE,DELETE" # Corresponds to ClickZetta's write permissions
|
||||||
LIST = "SELECT" # 列出文件需要SELECT权限
|
LIST = "SELECT" # Listing files requires SELECT permission
|
||||||
DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限
|
DELETE = "INSERT,UPDATE,DELETE" # Deleting files requires write permissions
|
||||||
USAGE = "USAGE" # External Volume需要的基本权限
|
USAGE = "USAGE" # Basic permission required for External Volume
|
||||||
|
|
||||||
|
|
||||||
class VolumePermissionManager:
|
class VolumePermissionManager:
|
||||||
"""Volume权限管理器"""
|
"""Volume permission manager"""
|
||||||
|
|
||||||
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
|
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
|
||||||
"""初始化权限管理器
|
"""Initialize permission manager
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connection_or_config: ClickZetta连接对象或配置字典
|
connection_or_config: ClickZetta connection object or configuration dictionary
|
||||||
volume_type: Volume类型 (user|table|external)
|
volume_type: Volume type (user|table|external)
|
||||||
volume_name: Volume名称 (用于external volume)
|
volume_name: Volume name (for external volume)
|
||||||
"""
|
"""
|
||||||
# 支持两种初始化方式:连接对象或配置字典
|
# Support two initialization methods: connection object or configuration dictionary
|
||||||
if isinstance(connection_or_config, dict):
|
if isinstance(connection_or_config, dict):
|
||||||
# 从配置字典创建连接
|
# Create connection from configuration dictionary
|
||||||
import clickzetta # type: ignore[import-untyped]
|
import clickzetta # type: ignore[import-untyped]
|
||||||
|
|
||||||
config = connection_or_config
|
config = connection_or_config
|
||||||
@ -50,7 +50,7 @@ class VolumePermissionManager:
|
|||||||
self._volume_type = config.get("volume_type", volume_type)
|
self._volume_type = config.get("volume_type", volume_type)
|
||||||
self._volume_name = config.get("volume_name", volume_name)
|
self._volume_name = config.get("volume_name", volume_name)
|
||||||
else:
|
else:
|
||||||
# 直接使用连接对象
|
# Use connection object directly
|
||||||
self._connection = connection_or_config
|
self._connection = connection_or_config
|
||||||
self._volume_type = volume_type
|
self._volume_type = volume_type
|
||||||
self._volume_name = volume_name
|
self._volume_name = volume_name
|
||||||
@ -61,14 +61,14 @@ class VolumePermissionManager:
|
|||||||
raise ValueError("volume_type is required")
|
raise ValueError("volume_type is required")
|
||||||
|
|
||||||
self._permission_cache: dict[str, set[str]] = {}
|
self._permission_cache: dict[str, set[str]] = {}
|
||||||
self._current_username = None # 将从连接中获取当前用户名
|
self._current_username = None # Will get current username from connection
|
||||||
|
|
||||||
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
|
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
|
||||||
"""检查用户是否有执行特定操作的权限
|
"""Check if user has permission to perform specific operation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
operation: 要执行的操作类型
|
operation: Type of operation to perform
|
||||||
dataset_id: 数据集ID (用于table volume)
|
dataset_id: Dataset ID (for table volume)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if user has permission, False otherwise
|
True if user has permission, False otherwise
|
||||||
@ -89,20 +89,20 @@ class VolumePermissionManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
||||||
"""检查User Volume权限
|
"""Check User Volume permission
|
||||||
|
|
||||||
User Volume权限规则:
|
User Volume permission rules:
|
||||||
- 用户对自己的User Volume有全部权限
|
- User has full permissions on their own User Volume
|
||||||
- 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限
|
- As long as user can connect to ClickZetta, they have basic User Volume permissions by default
|
||||||
- 更注重连接身份验证,而不是复杂的权限检查
|
- Focus more on connection authentication rather than complex permission checking
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取当前用户名
|
# Get current username
|
||||||
current_user = self._get_current_username()
|
current_user = self._get_current_username()
|
||||||
|
|
||||||
# 检查基本连接状态
|
# Check basic connection status
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
# 简单的连接测试,如果能执行查询说明用户有基本权限
|
# Simple connection test, if query can be executed user has basic permissions
|
||||||
cursor.execute("SELECT 1")
|
cursor.execute("SELECT 1")
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
|
||||||
@ -121,17 +121,18 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("User Volume permission check failed")
|
logger.exception("User Volume permission check failed")
|
||||||
# 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示
|
# For User Volume, if permission check fails, it might be a configuration issue,
|
||||||
|
# provide friendlier error message
|
||||||
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
|
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
|
||||||
"""检查Table Volume权限
|
"""Check Table Volume permission
|
||||||
|
|
||||||
Table Volume权限规则:
|
Table Volume permission rules:
|
||||||
- Table Volume权限继承对应表的权限
|
- Table Volume permissions inherit from corresponding table permissions
|
||||||
- SELECT权限 -> 可以READ/LIST文件
|
- SELECT permission -> can READ/LIST files
|
||||||
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
|
- INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
|
||||||
"""
|
"""
|
||||||
if not dataset_id:
|
if not dataset_id:
|
||||||
logger.warning("dataset_id is required for table volume permission check")
|
logger.warning("dataset_id is required for table volume permission check")
|
||||||
@ -140,11 +141,11 @@ class VolumePermissionManager:
|
|||||||
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查表权限
|
# Check table permissions
|
||||||
permissions = self._get_table_permissions(table_name)
|
permissions = self._get_table_permissions(table_name)
|
||||||
required_permissions = set(operation.value.split(","))
|
required_permissions = set(operation.value.split(","))
|
||||||
|
|
||||||
# 检查是否有所需的所有权限
|
# Check if has all required permissions
|
||||||
has_permission = required_permissions.issubset(permissions)
|
has_permission = required_permissions.issubset(permissions)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -163,22 +164,22 @@ class VolumePermissionManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
||||||
"""检查External Volume权限
|
"""Check External Volume permission
|
||||||
|
|
||||||
External Volume权限规则:
|
External Volume permission rules:
|
||||||
- 尝试获取对External Volume的权限
|
- Try to get permissions for External Volume
|
||||||
- 如果权限检查失败,进行备选验证
|
- If permission check fails, perform fallback verification
|
||||||
- 对于开发环境,提供更宽松的权限检查
|
- For development environment, provide more lenient permission checking
|
||||||
"""
|
"""
|
||||||
if not self._volume_name:
|
if not self._volume_name:
|
||||||
logger.warning("volume_name is required for external volume permission check")
|
logger.warning("volume_name is required for external volume permission check")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查External Volume权限
|
# Check External Volume permissions
|
||||||
permissions = self._get_external_volume_permissions(self._volume_name)
|
permissions = self._get_external_volume_permissions(self._volume_name)
|
||||||
|
|
||||||
# External Volume权限映射:根据操作类型确定所需权限
|
# External Volume permission mapping: determine required permissions based on operation type
|
||||||
required_permissions = set()
|
required_permissions = set()
|
||||||
|
|
||||||
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
||||||
@ -186,7 +187,7 @@ class VolumePermissionManager:
|
|||||||
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
||||||
required_permissions.add("write")
|
required_permissions.add("write")
|
||||||
|
|
||||||
# 检查是否有所需的所有权限
|
# Check if has all required permissions
|
||||||
has_permission = required_permissions.issubset(permissions)
|
has_permission = required_permissions.issubset(permissions)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -198,11 +199,11 @@ class VolumePermissionManager:
|
|||||||
has_permission,
|
has_permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果权限检查失败,尝试备选验证
|
# If permission check fails, try fallback verification
|
||||||
if not has_permission:
|
if not has_permission:
|
||||||
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
||||||
|
|
||||||
# 备选验证:尝试列出Volume来验证基本访问权限
|
# Fallback verification: try listing Volume to verify basic access permissions
|
||||||
try:
|
try:
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
cursor.execute("SHOW VOLUMES")
|
cursor.execute("SHOW VOLUMES")
|
||||||
@ -222,13 +223,13 @@ class VolumePermissionManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_table_permissions(self, table_name: str) -> set[str]:
|
def _get_table_permissions(self, table_name: str) -> set[str]:
|
||||||
"""获取用户对指定表的权限
|
"""Get user permissions for specified table
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table_name: 表名
|
table_name: Table name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
用户对该表的权限集合
|
Set of user permissions for this table
|
||||||
"""
|
"""
|
||||||
cache_key = f"table:{table_name}"
|
cache_key = f"table:{table_name}"
|
||||||
|
|
||||||
@ -239,18 +240,18 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
# 使用正确的ClickZetta语法检查当前用户权限
|
# Use correct ClickZetta syntax to check current user permissions
|
||||||
cursor.execute("SHOW GRANTS")
|
cursor.execute("SHOW GRANTS")
|
||||||
grants = cursor.fetchall()
|
grants = cursor.fetchall()
|
||||||
|
|
||||||
# 解析权限结果,查找对该表的权限
|
# Parse permission results, find permissions for this table
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
|
||||||
privilege = grant[0].upper()
|
privilege = grant[0].upper()
|
||||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||||
object_name = grant[2] if len(grant) > 2 else ""
|
object_name = grant[2] if len(grant) > 2 else ""
|
||||||
|
|
||||||
# 检查是否是对该表的权限
|
# Check if it's permission for this table
|
||||||
if (
|
if (
|
||||||
object_type == "TABLE"
|
object_type == "TABLE"
|
||||||
and object_name == table_name
|
and object_name == table_name
|
||||||
@ -263,7 +264,7 @@ class VolumePermissionManager:
|
|||||||
else:
|
else:
|
||||||
permissions.add(privilege)
|
permissions.add(privilege)
|
||||||
|
|
||||||
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
|
# If no explicit permissions found, try executing a simple query to verify permissions
|
||||||
if not permissions:
|
if not permissions:
|
||||||
try:
|
try:
|
||||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||||
@ -273,15 +274,15 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
||||||
# 安全默认:权限检查失败时拒绝访问
|
# Safe default: deny access when permission check fails
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 缓存权限信息
|
# Cache permission information
|
||||||
self._permission_cache[cache_key] = permissions
|
self._permission_cache[cache_key] = permissions
|
||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
def _get_current_username(self) -> str:
|
def _get_current_username(self) -> str:
|
||||||
"""获取当前用户名"""
|
"""Get current username"""
|
||||||
if self._current_username:
|
if self._current_username:
|
||||||
return self._current_username
|
return self._current_username
|
||||||
|
|
||||||
@ -298,7 +299,7 @@ class VolumePermissionManager:
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
def _get_user_permissions(self, username: str) -> set[str]:
|
def _get_user_permissions(self, username: str) -> set[str]:
|
||||||
"""获取用户的基本权限集合"""
|
"""Get user's basic permission set"""
|
||||||
cache_key = f"user_permissions:{username}"
|
cache_key = f"user_permissions:{username}"
|
||||||
|
|
||||||
if cache_key in self._permission_cache:
|
if cache_key in self._permission_cache:
|
||||||
@ -308,17 +309,17 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
# 使用正确的ClickZetta语法检查当前用户权限
|
# Use correct ClickZetta syntax to check current user permissions
|
||||||
cursor.execute("SHOW GRANTS")
|
cursor.execute("SHOW GRANTS")
|
||||||
grants = cursor.fetchall()
|
grants = cursor.fetchall()
|
||||||
|
|
||||||
# 解析权限结果,查找用户的基本权限
|
# Parse permission results, find user's basic permissions
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
if len(grant) >= 3: # Typical format: (privilege, object_type, object_name, ...)
|
||||||
privilege = grant[0].upper()
|
privilege = grant[0].upper()
|
||||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||||
|
|
||||||
# 收集所有相关权限
|
# Collect all relevant permissions
|
||||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||||
if privilege == "ALL":
|
if privilege == "ALL":
|
||||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||||
@ -327,21 +328,21 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Could not check user permissions for %s: %s", username, e)
|
logger.warning("Could not check user permissions for %s: %s", username, e)
|
||||||
# 安全默认:权限检查失败时拒绝访问
|
# Safe default: deny access when permission check fails
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 缓存权限信息
|
# Cache permission information
|
||||||
self._permission_cache[cache_key] = permissions
|
self._permission_cache[cache_key] = permissions
|
||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
||||||
"""获取用户对指定External Volume的权限
|
"""Get user permissions for specified External Volume
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
volume_name: External Volume名称
|
volume_name: External Volume name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
用户对该Volume的权限集合
|
Set of user permissions for this Volume
|
||||||
"""
|
"""
|
||||||
cache_key = f"external_volume:{volume_name}"
|
cache_key = f"external_volume:{volume_name}"
|
||||||
|
|
||||||
@ -352,15 +353,15 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
# 使用正确的ClickZetta语法检查Volume权限
|
# Use correct ClickZetta syntax to check Volume permissions
|
||||||
logger.info("Checking permissions for volume: %s", volume_name)
|
logger.info("Checking permissions for volume: %s", volume_name)
|
||||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||||
grants = cursor.fetchall()
|
grants = cursor.fetchall()
|
||||||
|
|
||||||
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
||||||
|
|
||||||
# 解析权限结果
|
# Parse permission results
|
||||||
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
# Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
||||||
# grantee_name, grantor_name, grant_option, granted_time)
|
# grantee_name, grantor_name, grant_option, granted_time)
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
logger.info("Processing grant: %s", grant)
|
logger.info("Processing grant: %s", grant)
|
||||||
@ -378,7 +379,7 @@ class VolumePermissionManager:
|
|||||||
object_name,
|
object_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否是对该Volume的权限或者是层级权限
|
# Check if it's permission for this Volume or hierarchical permission
|
||||||
if (
|
if (
|
||||||
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
||||||
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
||||||
@ -399,14 +400,14 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
||||||
|
|
||||||
# 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
|
# If no explicit permissions found, try viewing Volume list to verify basic permissions
|
||||||
if not permissions:
|
if not permissions:
|
||||||
try:
|
try:
|
||||||
cursor.execute("SHOW VOLUMES")
|
cursor.execute("SHOW VOLUMES")
|
||||||
volumes = cursor.fetchall()
|
volumes = cursor.fetchall()
|
||||||
for volume in volumes:
|
for volume in volumes:
|
||||||
if len(volume) > 0 and volume[0] == volume_name:
|
if len(volume) > 0 and volume[0] == volume_name:
|
||||||
permissions.add("read") # 至少有读权限
|
permissions.add("read") # At least has read permission
|
||||||
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -414,7 +415,7 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
||||||
# 在权限检查失败时,尝试基本的Volume访问验证
|
# When permission check fails, try basic Volume access verification
|
||||||
try:
|
try:
|
||||||
with self._connection.cursor() as cursor:
|
with self._connection.cursor() as cursor:
|
||||||
cursor.execute("SHOW VOLUMES")
|
cursor.execute("SHOW VOLUMES")
|
||||||
@ -423,30 +424,30 @@ class VolumePermissionManager:
|
|||||||
if len(volume) > 0 and volume[0] == volume_name:
|
if len(volume) > 0 and volume[0] == volume_name:
|
||||||
logger.info("Basic volume access verified for %s", volume_name)
|
logger.info("Basic volume access verified for %s", volume_name)
|
||||||
permissions.add("read")
|
permissions.add("read")
|
||||||
permissions.add("write") # 假设有写权限
|
permissions.add("write") # Assume has write permission
|
||||||
break
|
break
|
||||||
except Exception as basic_e:
|
except Exception as basic_e:
|
||||||
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
||||||
# 最后的备选方案:假设有基本权限
|
# Last fallback: assume basic permissions
|
||||||
permissions.add("read")
|
permissions.add("read")
|
||||||
|
|
||||||
# 缓存权限信息
|
# Cache permission information
|
||||||
self._permission_cache[cache_key] = permissions
|
self._permission_cache[cache_key] = permissions
|
||||||
return permissions
|
return permissions
|
||||||
|
|
||||||
def clear_permission_cache(self):
|
def clear_permission_cache(self):
|
||||||
"""清空权限缓存"""
|
"""Clear permission cache"""
|
||||||
self._permission_cache.clear()
|
self._permission_cache.clear()
|
||||||
logger.debug("Permission cache cleared")
|
logger.debug("Permission cache cleared")
|
||||||
|
|
||||||
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
|
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
|
||||||
"""获取权限摘要
|
"""Get permission summary
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_id: 数据集ID (用于table volume)
|
dataset_id: Dataset ID (for table volume)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
权限摘要字典
|
Permission summary dictionary
|
||||||
"""
|
"""
|
||||||
summary = {}
|
summary = {}
|
||||||
|
|
||||||
@ -456,43 +457,43 @@ class VolumePermissionManager:
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
||||||
"""检查文件路径的权限继承
|
"""Check permission inheritance for file path
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: 文件路径
|
file_path: File path
|
||||||
operation: 要执行的操作
|
operation: Operation to perform
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if user has permission, False otherwise
|
True if user has permission, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 解析文件路径
|
# Parse file path
|
||||||
path_parts = file_path.strip("/").split("/")
|
path_parts = file_path.strip("/").split("/")
|
||||||
|
|
||||||
if not path_parts:
|
if not path_parts:
|
||||||
logger.warning("Invalid file path for permission inheritance check")
|
logger.warning("Invalid file path for permission inheritance check")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 对于Table Volume,第一层是dataset_id
|
# For Table Volume, first layer is dataset_id
|
||||||
if self._volume_type == "table":
|
if self._volume_type == "table":
|
||||||
if len(path_parts) < 1:
|
if len(path_parts) < 1:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
dataset_id = path_parts[0]
|
dataset_id = path_parts[0]
|
||||||
|
|
||||||
# 检查对dataset的权限
|
# Check permissions for dataset
|
||||||
has_dataset_permission = self.check_permission(operation, dataset_id)
|
has_dataset_permission = self.check_permission(operation, dataset_id)
|
||||||
|
|
||||||
if not has_dataset_permission:
|
if not has_dataset_permission:
|
||||||
logger.debug("Permission denied for dataset %s", dataset_id)
|
logger.debug("Permission denied for dataset %s", dataset_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查路径遍历攻击
|
# Check path traversal attack
|
||||||
if self._contains_path_traversal(file_path):
|
if self._contains_path_traversal(file_path):
|
||||||
logger.warning("Path traversal attack detected: %s", file_path)
|
logger.warning("Path traversal attack detected: %s", file_path)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否访问敏感目录
|
# Check if accessing sensitive directory
|
||||||
if self._is_sensitive_path(file_path):
|
if self._is_sensitive_path(file_path):
|
||||||
logger.warning("Access to sensitive path denied: %s", file_path)
|
logger.warning("Access to sensitive path denied: %s", file_path)
|
||||||
return False
|
return False
|
||||||
@ -501,20 +502,20 @@ class VolumePermissionManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
elif self._volume_type == "user":
|
elif self._volume_type == "user":
|
||||||
# User Volume的权限继承
|
# User Volume permission inheritance
|
||||||
current_user = self._get_current_username()
|
current_user = self._get_current_username()
|
||||||
|
|
||||||
# 检查是否试图访问其他用户的目录
|
# Check if attempting to access other user's directory
|
||||||
if len(path_parts) > 1 and path_parts[0] != current_user:
|
if len(path_parts) > 1 and path_parts[0] != current_user:
|
||||||
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查基本权限
|
# Check basic permissions
|
||||||
return self.check_permission(operation)
|
return self.check_permission(operation)
|
||||||
|
|
||||||
elif self._volume_type == "external":
|
elif self._volume_type == "external":
|
||||||
# External Volume的权限继承
|
# External Volume permission inheritance
|
||||||
# 检查对External Volume的权限
|
# Check permissions for External Volume
|
||||||
return self.check_permission(operation)
|
return self.check_permission(operation)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -526,8 +527,8 @@ class VolumePermissionManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _contains_path_traversal(self, file_path: str) -> bool:
|
def _contains_path_traversal(self, file_path: str) -> bool:
|
||||||
"""检查路径是否包含路径遍历攻击"""
|
"""Check if path contains path traversal attack"""
|
||||||
# 检查常见的路径遍历模式
|
# Check common path traversal patterns
|
||||||
traversal_patterns = [
|
traversal_patterns = [
|
||||||
"../",
|
"../",
|
||||||
"..\\",
|
"..\\",
|
||||||
@ -547,18 +548,18 @@ class VolumePermissionManager:
|
|||||||
if pattern in file_path_lower:
|
if pattern in file_path_lower:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查绝对路径
|
# Check absolute path
|
||||||
if file_path.startswith("/") or file_path.startswith("\\"):
|
if file_path.startswith("/") or file_path.startswith("\\"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查Windows驱动器路径
|
# Check Windows drive path
|
||||||
if len(file_path) >= 2 and file_path[1] == ":":
|
if len(file_path) >= 2 and file_path[1] == ":":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _is_sensitive_path(self, file_path: str) -> bool:
|
def _is_sensitive_path(self, file_path: str) -> bool:
|
||||||
"""检查路径是否为敏感路径"""
|
"""Check if path is sensitive path"""
|
||||||
sensitive_patterns = [
|
sensitive_patterns = [
|
||||||
"passwd",
|
"passwd",
|
||||||
"shadow",
|
"shadow",
|
||||||
@ -582,11 +583,11 @@ class VolumePermissionManager:
|
|||||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||||
|
|
||||||
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
|
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
|
||||||
"""验证操作权限
|
"""Validate operation permission
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
operation: 操作名称 (save|load|exists|delete|scan)
|
operation: Operation name (save|load|exists|delete|scan)
|
||||||
dataset_id: 数据集ID
|
dataset_id: Dataset ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if operation is allowed, False otherwise
|
True if operation is allowed, False otherwise
|
||||||
@ -611,7 +612,7 @@ class VolumePermissionManager:
|
|||||||
|
|
||||||
|
|
||||||
class VolumePermissionError(Exception):
|
class VolumePermissionError(Exception):
|
||||||
"""Volume权限错误异常"""
|
"""Volume permission error exception"""
|
||||||
|
|
||||||
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
|
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
|
||||||
self.operation = operation
|
self.operation = operation
|
||||||
@ -623,15 +624,15 @@ class VolumePermissionError(Exception):
|
|||||||
def check_volume_permission(
|
def check_volume_permission(
|
||||||
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
|
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""权限检查装饰器函数
|
"""Permission check decorator function
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
permission_manager: 权限管理器
|
permission_manager: Permission manager
|
||||||
operation: 操作名称
|
operation: Operation name
|
||||||
dataset_id: 数据集ID
|
dataset_id: Dataset ID
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
VolumePermissionError: 如果没有权限
|
VolumePermissionError: If no permission
|
||||||
"""
|
"""
|
||||||
if not permission_manager.validate_operation(operation, dataset_id):
|
if not permission_manager.validate_operation(operation, dataset_id):
|
||||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
||||||
|
|||||||
@ -1,11 +0,0 @@
|
|||||||
from tests.integration_tests.utils.parent_class import ParentClass
|
|
||||||
|
|
||||||
|
|
||||||
class LazyLoadChildClass(ParentClass):
    """Fixture subclass of ParentClass used by the module import helper tests.

    Unlike a decorating child, ``get_name`` hands back the stored name as-is.
    """

    def __init__(self, name):
        # Delegate storage of ``name`` entirely to ParentClass.
        super().__init__(name)

    def get_name(self):
        """Return the ``name`` attribute unchanged."""
        return self.name
|
|
||||||
@ -0,0 +1,716 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
|
||||||
|
from models.tools import WorkflowToolProvider
|
||||||
|
from models.workflow import Workflow as WorkflowModel
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.app_service import AppService
|
||||||
|
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowToolManageService:
|
||||||
|
"""Integration tests for WorkflowToolManageService using testcontainers."""
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_external_service_dependencies(self):
    """Patch the external collaborators used by the services under test.

    The patches cover FeatureService, EnterpriseService and ModelManager as
    referenced from ``services.app_service``, FeatureService as referenced
    from ``services.account_service``, and the provider controller, label
    manager and transform service referenced from
    ``services.tools.workflow_tools_manage_service``.

    Yields:
        dict: the active mocks keyed by a short name, so tests can adjust
        return values or assert on recorded calls.
    """
    with (
        patch("services.app_service.FeatureService") as app_features,
        patch("services.app_service.EnterpriseService") as enterprise,
        patch("services.app_service.ModelManager") as model_manager_cls,
        patch("services.account_service.FeatureService") as account_features,
        patch(
            "services.tools.workflow_tools_manage_service.WorkflowToolProviderController"
        ) as provider_controller,
        patch("services.tools.workflow_tools_manage_service.ToolLabelManager") as label_manager,
        patch("services.tools.workflow_tools_manage_service.ToolTransformService") as transform_service,
    ):
        # App service defaults: web-app auth off, enterprise hooks return None.
        app_features.get_system_features.return_value.webapp_auth.enabled = False
        web_app_auth = enterprise.WebAppAuth
        web_app_auth.update_app_access_mode.return_value = None
        web_app_auth.cleanup_webapp.return_value = None

        # Account service default: registration is allowed.
        account_features.get_system_features.return_value.is_allow_register = True

        # Whatever ModelManager() returns reports no default model instance
        # and a fixed (provider, model) pair.
        manager_instance = model_manager_cls.return_value
        manager_instance.get_default_model_instance.return_value = None
        manager_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")

        # Tool-side collaborators all return None by default.
        provider_controller.from_db.return_value = None
        label_manager.update_tool_labels.return_value = None
        transform_service.workflow_provider_to_controller.return_value = None

        mocks = {
            "feature_service": app_features,
            "enterprise_service": enterprise,
            "model_manager": model_manager_cls,
            "account_feature_service": account_features,
            "workflow_tool_provider_controller": provider_controller,
            "tool_label_manager": label_manager,
            "tool_transform_service": transform_service,
        }
        yield mocks
|
||||||
|
|
||||||
|
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Build the account, tenant, workflow-mode app and workflow row a test needs.

    Args:
        db_session_with_containers: Database session from testcontainers infrastructure
            (not referenced directly in this helper)
        mock_external_service_dependencies: Mock dependencies yielded by the fixture

    Returns:
        tuple: (app, account, workflow) - the app already points at the workflow
    """
    from extensions.ext_database import db

    fake = Faker()

    # Account creation consults the (mocked) feature flags; allow registration.
    account_features = mock_external_service_dependencies["account_feature_service"]
    account_features.get_system_features.return_value.is_allow_register = True

    # Account first, then its owner tenant.
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant

    # A workflow-mode app owned by that tenant.
    app = AppService().create_app(
        tenant.id,
        {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "workflow",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        },
        account,
    )

    # Persist a workflow with an empty graph and empty features.
    empty_json = json.dumps({})
    workflow = WorkflowModel(
        tenant_id=tenant.id,
        app_id=app.id,
        type="workflow",
        version="1.0.0",
        graph=empty_json,
        features=empty_json,
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
    )
    db.session.add(workflow)
    db.session.commit()

    # Link the app to the workflow once the workflow id is available.
    app.workflow_id = workflow.id
    db.session.commit()

    return app, account, workflow
|
||||||
|
|
||||||
|
def _create_test_workflow_tool_parameters(self):
|
||||||
|
"""Helper method to create valid workflow tool parameters."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_text",
|
||||||
|
"description": "Input text for processing",
|
||||||
|
"form": "form",
|
||||||
|
"type": "string",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"description": "Output format specification",
|
||||||
|
"form": "form",
|
||||||
|
"type": "select",
|
||||||
|
"required": False,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Create a workflow tool from valid input and inspect what was stored.

    This test verifies:
    - The service returns {"result": "success"}
    - A WorkflowToolProvider row exists for the tenant/app with the submitted fields
      (icon and parameters compared against their JSON-serialized form)
    - The stored version matches the workflow's version
    - The patched controller, label manager and transform service are each called once
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # Inputs for the tool being created.
    name = fake.word()
    label = fake.word()
    icon = {"type": "emoji", "emoji": "🔧"}
    description = fake.text(max_nb_chars=200)
    parameters = self._create_test_workflow_tool_parameters()
    privacy_policy = fake.text(max_nb_chars=100)
    labels = ["automation", "workflow"]

    outcome = WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=tenant_id,
        workflow_app_id=app.id,
        name=name,
        label=label,
        icon=icon,
        description=description,
        parameters=parameters,
        privacy_policy=privacy_policy,
        labels=labels,
    )

    assert outcome == {"result": "success"}

    # Look the new provider up by tenant and app.
    provider = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == tenant_id,
            WorkflowToolProvider.app_id == app.id,
        )
        .first()
    )

    assert provider is not None
    assert provider.name == name
    assert provider.label == label
    assert provider.icon == json.dumps(icon)
    assert provider.description == description
    assert provider.parameter_configuration == json.dumps(parameters)
    assert provider.privacy_policy == privacy_policy
    assert provider.version == workflow.version
    assert provider.user_id == account.id
    assert provider.tenant_id == tenant_id
    assert provider.app_id == app.id

    # Each patched collaborator must have been invoked exactly once.
    mocks = mock_external_service_dependencies
    mocks["workflow_tool_provider_controller"].from_db.assert_called_once()
    mocks["tool_label_manager"].update_tool_labels.assert_called_once()
    mocks["tool_transform_service"].workflow_provider_to_controller.assert_called_once()
|
||||||
|
|
||||||
|
def test_create_workflow_tool_duplicate_name_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Creating a second tool that reuses an existing name must be rejected.

    This test verifies:
    - A ValueError is raised on the second creation
    - The error text names the tool name and app id
    - Only the first tool exists for the tenant afterwards
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # First creation is expected to go through.
    taken_name = fake.word()
    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=tenant_id,
        workflow_app_id=app.id,
        name=taken_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=self._create_test_workflow_tool_parameters(),
    )

    # Second creation reuses the name (and the app) and must fail.
    retry_parameters = self._create_test_workflow_tool_parameters()
    with pytest.raises(ValueError) as raised:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_app_id=app.id,
            name=taken_name,
            label=fake.word(),
            icon={"type": "emoji", "emoji": "⚙️"},
            description=fake.text(max_nb_chars=200),
            parameters=retry_parameters,
        )

    assert f"Tool with name {taken_name} or app_id {app.id} already exists" in str(raised.value)

    # The failed attempt must not have added a row.
    tenant_tools = db.session.query(WorkflowToolProvider).where(
        WorkflowToolProvider.tenant_id == tenant_id,
    )
    assert tenant_tools.count() == 1
|
||||||
|
|
||||||
|
def test_create_workflow_tool_invalid_app_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Creating a tool for an app id that does not exist must be rejected.

    This test verifies:
    - A ValueError is raised
    - The error text reports the missing app id
    - No tool row is created for the tenant
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # A random UUID stands in for an app that was never created.
    missing_app_id = fake.uuid4()
    parameters = self._create_test_workflow_tool_parameters()

    with pytest.raises(ValueError) as raised:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_app_id=missing_app_id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=parameters,
        )

    assert f"App {missing_app_id} not found" in str(raised.value)

    # Nothing may have been persisted for this tenant.
    tenant_tools = db.session.query(WorkflowToolProvider).where(
        WorkflowToolProvider.tenant_id == tenant_id,
    )
    assert tenant_tools.count() == 0
|
||||||
|
|
||||||
|
def test_create_workflow_tool_invalid_parameters_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Creating a tool with an incomplete parameter definition must be rejected.

    This test verifies:
    - A ValueError is raised for a parameter lacking description and form
    - The error text mentions a validation error (case-insensitive)
    - No tool row is created for the tenant
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # Deliberately leaves out "description" and "form".
    incomplete_parameter = {
        "name": "input_text",
        "type": "string",
        "required": True,
    }

    with pytest.raises(ValueError) as raised:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_app_id=app.id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=[incomplete_parameter],
        )

    assert "validation error" in str(raised.value).lower()

    # Nothing may have been persisted for this tenant.
    tenant_tools = db.session.query(WorkflowToolProvider).where(
        WorkflowToolProvider.tenant_id == tenant_id,
    )
    assert tenant_tools.count() == 0
|
||||||
|
|
||||||
|
def test_create_workflow_tool_duplicate_app_id_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test workflow tool creation fails when app_id already exists.

    This test verifies:
    - Proper error handling for duplicate app_id
    - The rejection is triggered by the app_id alone (the name is guaranteed to differ)
    - Correct error message
    - Only the first tool exists for the tenant afterwards
    """
    from extensions.ext_database import db

    fake = Faker()

    # Create test data
    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # Create first workflow tool
    first_tool_name = fake.word()
    first_tool_parameters = self._create_test_workflow_tool_parameters()

    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=tenant_id,
        workflow_app_id=app.id,
        name=first_tool_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=first_tool_parameters,
    )

    # Derive the second name from the first instead of drawing another random
    # word: two independent fake.word() calls can return the same word, which
    # would silently turn this into a duplicate-name test.
    second_tool_name = f"{first_tool_name}_second"
    second_tool_parameters = self._create_test_workflow_tool_parameters()

    with pytest.raises(ValueError) as exc_info:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_app_id=app.id,  # Same app_id
            name=second_tool_name,  # Guaranteed different name
            label=fake.word(),
            icon={"type": "emoji", "emoji": "⚙️"},
            description=fake.text(max_nb_chars=200),
            parameters=second_tool_parameters,
        )

    # Verify error message
    assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)

    # Verify only one tool was created
    tool_count = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == tenant_id,
        )
        .count()
    )

    assert tool_count == 1
|
||||||
|
|
||||||
|
def test_create_workflow_tool_workflow_not_found_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Creating a tool for an app whose workflow reference is unset must be rejected.

    This test verifies:
    - A ValueError is raised once app.workflow_id has been cleared
    - The error text reports the app id
    - No tool row is created for the tenant
    """
    from extensions.ext_database import db

    fake = Faker()

    # The helper links a workflow to the app; the link is removed below.
    app, account, _ = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # Detach the workflow from the app and persist that state.
    app.workflow_id = None
    db.session.commit()

    parameters = self._create_test_workflow_tool_parameters()

    with pytest.raises(ValueError) as raised:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_app_id=app.id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=parameters,
        )

    assert f"Workflow not found for app {app.id}" in str(raised.value)

    # Nothing may have been persisted for this tenant.
    tenant_tools = db.session.query(WorkflowToolProvider).where(
        WorkflowToolProvider.tenant_id == tenant_id,
    )
    assert tenant_tools.count() == 0
|
||||||
|
|
||||||
|
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Update an existing workflow tool with valid input and inspect the stored row.

    This test verifies:
    - The service returns {"result": "success"}
    - The refreshed row carries the new name, label, icon, description,
      parameters and privacy policy (icon/parameters compared as JSON text)
    - The stored version matches the workflow's version and updated_at is set
    - The patched controller, label manager and transform service were called
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # Seed the tool that will be updated.
    original_name = fake.word()
    original_parameters = self._create_test_workflow_tool_parameters()
    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=tenant_id,
        workflow_app_id=app.id,
        name=original_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=original_parameters,
    )

    # Fetch the row to learn its id.
    tool = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == tenant_id,
            WorkflowToolProvider.app_id == app.id,
        )
        .first()
    )

    # New values for every updatable field.
    new_name = fake.word()
    new_label = fake.word()
    new_icon = {"type": "emoji", "emoji": "⚙️"}
    new_description = fake.text(max_nb_chars=200)
    new_parameters = self._create_test_workflow_tool_parameters()
    new_privacy_policy = fake.text(max_nb_chars=100)
    new_labels = ["automation", "updated"]

    outcome = WorkflowToolManageService.update_workflow_tool(
        user_id=account.id,
        tenant_id=tenant_id,
        workflow_tool_id=tool.id,
        name=new_name,
        label=new_label,
        icon=new_icon,
        description=new_description,
        parameters=new_parameters,
        privacy_policy=new_privacy_policy,
        labels=new_labels,
    )

    assert outcome == {"result": "success"}

    # Reload the row so the assertions see the persisted state.
    db.session.refresh(tool)
    assert tool.name == new_name
    assert tool.label == new_label
    assert tool.icon == json.dumps(new_icon)
    assert tool.description == new_description
    assert tool.parameter_configuration == json.dumps(new_parameters)
    assert tool.privacy_policy == new_privacy_policy
    assert tool.version == workflow.version
    assert tool.updated_at is not None

    # Collaborators are hit by both create and update, so only "called" is checked.
    mocks = mock_external_service_dependencies
    mocks["workflow_tool_provider_controller"].from_db.assert_called()
    mocks["tool_label_manager"].update_tool_labels.assert_called()
    mocks["tool_transform_service"].workflow_provider_to_controller.assert_called()
|
||||||
|
|
||||||
|
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
    """
    Updating a tool id that does not exist must be rejected.

    This test verifies:
    - A ValueError is raised
    - The error text reports the missing tool id
    - The tenant still has no tool rows afterwards
    """
    from extensions.ext_database import db

    fake = Faker()

    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    tenant_id = account.current_tenant.id

    # A random UUID stands in for a tool that was never created.
    missing_tool_id = fake.uuid4()
    parameters = self._create_test_workflow_tool_parameters()

    with pytest.raises(ValueError) as raised:
        WorkflowToolManageService.update_workflow_tool(
            user_id=account.id,
            tenant_id=tenant_id,
            workflow_tool_id=missing_tool_id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=parameters,
        )

    assert f"Tool {missing_tool_id} not found" in str(raised.value)

    # The failed update must not have produced any row.
    tenant_tools = db.session.query(WorkflowToolProvider).where(
        WorkflowToolProvider.tenant_id == tenant_id,
    )
    assert tenant_tools.count() == 0
|
||||||
|
|
||||||
|
def test_update_workflow_tool_same_name_success(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test workflow tool update succeeds when keeping the same name.

    This test verifies:
    - Updating a tool while passing its current name returns a success result
    - The stored tool keeps that name after the update
    - The stored tool has a non-null update timestamp
    """
    fake = Faker()

    # Create test data
    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )

    # Create the workflow tool that will be updated below
    first_tool_name = fake.word()
    first_tool_parameters = self._create_test_workflow_tool_parameters()

    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=account.current_tenant.id,
        workflow_app_id=app.id,
        name=first_tool_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=first_tool_parameters,
    )

    # Look the created tool up by tenant and app
    from extensions.ext_database import db

    created_tool = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == account.current_tenant.id,
            WorkflowToolProvider.app_id == app.id,
        )
        .first()
    )

    # .first() returns None when no row matches; fail here with a clear
    # message instead of an AttributeError on created_tool.id below.
    assert created_tool is not None, "workflow tool was not persisted by create_workflow_tool"

    # Update the tool while keeping its current name (should not fail)
    result = WorkflowToolManageService.update_workflow_tool(
        user_id=account.id,
        tenant_id=account.current_tenant.id,
        workflow_tool_id=created_tool.id,
        name=first_tool_name,  # Same name
        label=fake.word(),
        icon={"type": "emoji", "emoji": "⚙️"},
        description=fake.text(max_nb_chars=200),
        parameters=first_tool_parameters,
    )

    # Verify update was successful
    assert result == {"result": "success"}

    # Reload the row so the assertions see the persisted state
    db.session.refresh(created_tool)
    assert created_tool.name == first_tool_name
    assert created_tool.updated_at is not None
|
||||||
@ -14,11 +14,5 @@ uv run --directory api --dev ruff format ./
|
|||||||
# run dotenv-linter linter
|
# run dotenv-linter linter
|
||||||
uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||||
|
|
||||||
# run import-linter
|
|
||||||
uv run --directory api --dev lint-imports
|
|
||||||
|
|
||||||
# run ty check
|
|
||||||
dev/ty-check
|
|
||||||
|
|
||||||
# run mypy check
|
# run mypy check
|
||||||
dev/mypy-check
|
dev/mypy-check
|
||||||
|
|||||||
@ -41,15 +41,6 @@ if $api_modified; then
|
|||||||
echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# run ty checks
|
|
||||||
uv run --directory api --dev ty check || status=$?
|
|
||||||
status=${status:-0}
|
|
||||||
if [ $status -ne 0 ]; then
|
|
||||||
echo "ty type checker on api module error, exit code: $status"
|
|
||||||
echo "Please run 'dev/ty-check' to check the type errors."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if $web_modified; then
|
if $web_modified; then
|
||||||
|
|||||||
@ -38,7 +38,7 @@ const Field: FC<Props> = ({
|
|||||||
<div className={cn(className, inline && 'flex w-full items-center justify-between')}>
|
<div className={cn(className, inline && 'flex w-full items-center justify-between')}>
|
||||||
<div
|
<div
|
||||||
onClick={() => supportFold && toggleFold()}
|
onClick={() => supportFold && toggleFold()}
|
||||||
className={cn('sticky top-0 z-10 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
|
className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}>
|
||||||
<div className='flex h-6 items-center'>
|
<div className='flex h-6 items-center'>
|
||||||
<div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}>
|
<div className={cn(isSubTitle ? 'system-xs-medium-uppercase text-text-tertiary' : 'system-sm-semibold-uppercase text-text-secondary')}>
|
||||||
{title} {required && <span className='text-text-destructive'>*</span>}
|
{title} {required && <span className='text-text-destructive'>*</span>}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user