mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/feat/rag-2' into feat/rag-2
This commit is contained in:
commit
df5a4e5c08
File diff suppressed because it is too large
Load Diff
|
|
@ -4,6 +4,23 @@ title: "[Chore/Refactor] "
|
|||
labels:
|
||||
- refactor
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
|
||||
required: true
|
||||
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report, otherwise it will be closed.
|
||||
required: true
|
||||
- label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :)
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,10 @@ on:
|
|||
types: [closed]
|
||||
branches: [main]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-and-update:
|
||||
if: github.event.pull_request.merged == true
|
||||
|
|
@ -16,7 +20,7 @@ jobs:
|
|||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
persist-credentials: false
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
|
|
@ -49,7 +53,7 @@ jobs:
|
|||
if: env.FILES_CHANGED == 'true'
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
- name: Generate i18n translations
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: pnpm run auto-gen-i18n
|
||||
|
||||
|
|
@ -57,6 +61,7 @@ jobs:
|
|||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
|
|
|
|||
|
|
@ -215,3 +215,10 @@ mise.toml
|
|||
# AI Assistant
|
||||
.roo/
|
||||
api/.env.backup
|
||||
|
||||
# Clickzetta test credentials
|
||||
.env.clickzetta
|
||||
.env.clickzetta.test
|
||||
|
||||
# Clickzetta plugin development folder (keep local, ignore for PR)
|
||||
clickzetta/
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
|||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
|
||||
from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
|
||||
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from .storage.oci_storage_config import OCIStorageConfig
|
||||
|
|
@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
|||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
from .vdb.clickzetta_config import ClickzettaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.huawei_cloud_config import HuaweiCloudConfig
|
||||
|
|
@ -52,6 +54,7 @@ class StorageConfig(BaseSettings):
|
|||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
|
|
@ -61,8 +64,9 @@ class StorageConfig(BaseSettings):
|
|||
"local",
|
||||
] = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
|
||||
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
|
||||
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
)
|
||||
|
||||
|
|
@ -303,6 +307,7 @@ class MiddlewareConfig(
|
|||
AliyunOSSStorageConfig,
|
||||
AzureBlobStorageConfig,
|
||||
BaiduOBSStorageConfig,
|
||||
ClickZettaVolumeStorageConfig,
|
||||
GoogleCloudStorageConfig,
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
OCIStorageConfig,
|
||||
|
|
@ -315,6 +320,7 @@ class MiddlewareConfig(
|
|||
VectorStoreConfig,
|
||||
AnalyticdbConfig,
|
||||
ChromaConfig,
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
MilvusConfig,
|
||||
MyScaleConfig,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
"""ClickZetta Volume Storage Configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ClickZettaVolumeStorageConfig(BaseSettings):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field(
|
||||
description="Username for ClickZetta Volume authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field(
|
||||
description="Password for ClickZetta Volume authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field(
|
||||
description="ClickZetta instance identifier",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_SERVICE: str = Field(
|
||||
description="ClickZetta service endpoint",
|
||||
default="api.clickzetta.com",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_WORKSPACE: str = Field(
|
||||
description="ClickZetta workspace name",
|
||||
default="quick_start",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_VCLUSTER: str = Field(
|
||||
description="ClickZetta virtual cluster name",
|
||||
default="default_ap",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_SCHEMA: str = Field(
|
||||
description="ClickZetta schema name",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_TYPE: str = Field(
|
||||
description="ClickZetta volume type (table|user|external)",
|
||||
default="user",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_NAME: Optional[str] = Field(
|
||||
description="ClickZetta volume name for external volumes",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
|
||||
description="Prefix for ClickZetta volume table names",
|
||||
default="dataset_",
|
||||
)
|
||||
|
||||
CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
|
||||
description="Directory prefix for User Volume to organize Dify files",
|
||||
default="dify_km",
|
||||
)
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ClickzettaConfig(BaseModel):
|
||||
"""
|
||||
Clickzetta Lakehouse vector database configuration
|
||||
"""
|
||||
|
||||
CLICKZETTA_USERNAME: Optional[str] = Field(
|
||||
description="Username for authenticating with Clickzetta Lakehouse",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_PASSWORD: Optional[str] = Field(
|
||||
description="Password for authenticating with Clickzetta Lakehouse",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_INSTANCE: Optional[str] = Field(
|
||||
description="Clickzetta Lakehouse instance ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CLICKZETTA_SERVICE: Optional[str] = Field(
|
||||
description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
|
||||
default="api.clickzetta.com",
|
||||
)
|
||||
|
||||
CLICKZETTA_WORKSPACE: Optional[str] = Field(
|
||||
description="Clickzetta workspace name",
|
||||
default="default",
|
||||
)
|
||||
|
||||
CLICKZETTA_VCLUSTER: Optional[str] = Field(
|
||||
description="Clickzetta virtual cluster name",
|
||||
default="default_ap",
|
||||
)
|
||||
|
||||
CLICKZETTA_SCHEMA: Optional[str] = Field(
|
||||
description="Database schema name in Clickzetta",
|
||||
default="public",
|
||||
)
|
||||
|
||||
CLICKZETTA_BATCH_SIZE: Optional[int] = Field(
|
||||
description="Batch size for bulk insert operations",
|
||||
default=100,
|
||||
)
|
||||
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field(
|
||||
description="Enable inverted index for full-text search capabilities",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field(
|
||||
description="Analyzer type for full-text search: keyword, english, chinese, unicode",
|
||||
default="chinese",
|
||||
)
|
||||
|
||||
CLICKZETTA_ANALYZER_MODE: Optional[str] = Field(
|
||||
description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
|
||||
default="smart",
|
||||
)
|
||||
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field(
|
||||
description="Distance function for vector similarity: l2_distance or cosine_distance",
|
||||
default="cosine_distance",
|
||||
)
|
||||
|
|
@ -694,6 +694,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
|
@ -742,6 +743,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
|
|
|||
|
|
@ -49,7 +49,6 @@ class FileApi(Resource):
|
|||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def post(self):
|
||||
file = request.files["file"]
|
||||
source_str = request.form.get("source")
|
||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||
|
||||
|
|
@ -58,6 +57,7 @@ class FileApi(Resource):
|
|||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
file = request.files["file"]
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
|
|
|||
|
|
@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
|
@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
|
|
|
|||
|
|
@ -20,18 +20,17 @@ class FileApi(Resource):
|
|||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
file = request.files["file"]
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
|
|
|
|||
|
|
@ -234,8 +234,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
|
@ -243,6 +241,8 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
|
|
|
|||
|
|
@ -12,18 +12,17 @@ from services.file_service import FileService
|
|||
class FileApi(WebApiResource):
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
file = request.files["file"]
|
||||
source = request.form.get("source")
|
||||
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
source = request.form.get("source")
|
||||
if source not in ("datasets", None):
|
||||
source = None
|
||||
|
||||
|
|
|
|||
|
|
@ -121,9 +121,8 @@ class TokenBufferMemory:
|
|||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
pruned_memory.append(prompt_messages.pop(0))
|
||||
prompt_messages.pop(0)
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
|
|
|||
|
|
@ -0,0 +1,190 @@
|
|||
# Clickzetta Vector Database Integration
|
||||
|
||||
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
|
||||
|
||||
## Features
|
||||
|
||||
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
|
||||
- **Vector Search**: Efficient similarity search using HNSW algorithm
|
||||
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
|
||||
- **Hybrid Search**: Combine vector similarity and full-text search for better results
|
||||
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
|
||||
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
All seven configuration parameters are required:
|
||||
|
||||
```bash
|
||||
# Authentication
|
||||
CLICKZETTA_USERNAME=your_username
|
||||
CLICKZETTA_PASSWORD=your_password
|
||||
|
||||
# Instance configuration
|
||||
CLICKZETTA_INSTANCE=your_instance_id
|
||||
CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
CLICKZETTA_WORKSPACE=your_workspace
|
||||
CLICKZETTA_VCLUSTER=your_vcluster
|
||||
CLICKZETTA_SCHEMA=your_schema
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```bash
|
||||
# Batch processing
|
||||
CLICKZETTA_BATCH_SIZE=100
|
||||
|
||||
# Full-text search configuration
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX=true
|
||||
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
|
||||
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
|
||||
|
||||
# Vector search configuration
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Set Clickzetta as the Vector Store
|
||||
|
||||
In your Dify configuration, set:
|
||||
|
||||
```bash
|
||||
VECTOR_STORE=clickzetta
|
||||
```
|
||||
|
||||
### 2. Table Structure
|
||||
|
||||
Clickzetta will automatically create tables with the following structure:
|
||||
|
||||
```sql
|
||||
CREATE TABLE <collection_name> (
|
||||
id STRING NOT NULL,
|
||||
content STRING NOT NULL,
|
||||
metadata JSON,
|
||||
vector VECTOR(FLOAT, <dimension>) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
-- Vector index for similarity search
|
||||
CREATE VECTOR INDEX idx_<collection_name>_vec
|
||||
ON TABLE <schema>.<collection_name>(vector)
|
||||
PROPERTIES (
|
||||
"distance.function" = "cosine_distance",
|
||||
"scalar.type" = "f32"
|
||||
);
|
||||
|
||||
-- Inverted index for full-text search (if enabled)
|
||||
CREATE INVERTED INDEX idx_<collection_name>_text
|
||||
ON <schema>.<collection_name>(content)
|
||||
PROPERTIES (
|
||||
"analyzer" = "chinese",
|
||||
"mode" = "smart"
|
||||
);
|
||||
```
|
||||
|
||||
## Full-Text Search Capabilities
|
||||
|
||||
Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
|
||||
### Analyzer Types
|
||||
|
||||
1. **keyword**: No tokenization, treats the entire string as a single token
|
||||
- Best for: Exact matching, IDs, codes
|
||||
|
||||
2. **english**: Designed for English text
|
||||
- Features: Recognizes ASCII letters and numbers, converts to lowercase
|
||||
- Best for: English content
|
||||
|
||||
3. **chinese**: Chinese text tokenizer
|
||||
- Features: Recognizes Chinese and English characters, removes punctuation
|
||||
- Best for: Chinese or mixed Chinese-English content
|
||||
|
||||
4. **unicode**: Multi-language tokenizer based on Unicode
|
||||
- Features: Recognizes text boundaries in multiple languages
|
||||
- Best for: Multi-language content
|
||||
|
||||
### Analyzer Modes
|
||||
|
||||
- **max_word**: Fine-grained tokenization (more tokens)
|
||||
- **smart**: Intelligent tokenization (balanced)
|
||||
|
||||
### Full-Text Search Functions
|
||||
|
||||
- `MATCH_ALL(column, query)`: All terms must be present
|
||||
- `MATCH_ANY(column, query)`: At least one term must be present
|
||||
- `MATCH_PHRASE(column, query)`: Exact phrase matching
|
||||
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
|
||||
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Vector Search
|
||||
|
||||
1. **Adjust exploration factor** for accuracy vs speed trade-off:
|
||||
```sql
|
||||
SET cz.vector.index.search.ef=64;
|
||||
```
|
||||
|
||||
2. **Use appropriate distance functions**:
|
||||
- `cosine_distance`: Best for normalized embeddings (e.g., from language models)
|
||||
- `l2_distance`: Best for raw feature vectors
|
||||
|
||||
### Full-Text Search
|
||||
|
||||
1. **Choose the right analyzer**:
|
||||
- Use `keyword` for exact matching
|
||||
- Use language-specific analyzers for better tokenization
|
||||
|
||||
2. **Combine with vector search**:
|
||||
- Pre-filter with full-text search for better performance
|
||||
- Use hybrid search for improved relevance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Issues
|
||||
|
||||
1. Verify all 7 required configuration parameters are set
|
||||
2. Check network connectivity to Clickzetta service
|
||||
3. Ensure the user has proper permissions on the schema
|
||||
|
||||
### Search Performance
|
||||
|
||||
1. Verify vector index exists:
|
||||
```sql
|
||||
SHOW INDEX FROM <schema>.<table_name>;
|
||||
```
|
||||
|
||||
2. Check if vector index is being used:
|
||||
```sql
|
||||
EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
|
||||
```
|
||||
Look for `vector_index_search_type` in the execution plan.
|
||||
|
||||
### Full-Text Search Not Working
|
||||
|
||||
1. Verify inverted index is created
|
||||
2. Check analyzer configuration matches your content language
|
||||
3. Use `TOKENIZE()` function to test tokenization:
|
||||
```sql
|
||||
SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
|
||||
2. Full-text search relevance scores are not provided by Clickzetta
|
||||
3. Inverted index creation may fail for very large existing tables (continue without error)
|
||||
4. Index naming constraints:
|
||||
- Index names must be unique within a schema
|
||||
- Only one vector index can be created per column
|
||||
- The implementation uses timestamps to ensure unique index names
|
||||
5. A column can only have one vector index at a time
|
||||
|
||||
## References
|
||||
|
||||
- [Clickzetta Vector Search Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/vector-search.md)
|
||||
- [Clickzetta Inverted Index Documentation](../../../../../../../yunqidoc/cn_markdown_20250526/inverted-index.md)
|
||||
- [Clickzetta SQL Functions](../../../../../../../yunqidoc/cn_markdown_20250526/sql_functions/)
|
||||
|
|
@ -0,0 +1 @@
|
|||
# Clickzetta Vector Database Integration for Dify
|
||||
|
|
@ -0,0 +1,834 @@
|
|||
import json
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import clickzetta # type: ignore
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clickzetta import Connection
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ClickZetta Lakehouse Vector Database Configuration
|
||||
|
||||
|
||||
class ClickzettaConfig(BaseModel):
|
||||
"""
|
||||
Configuration class for Clickzetta connection.
|
||||
"""
|
||||
|
||||
username: str
|
||||
password: str
|
||||
instance: str
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify" # Renamed to avoid shadowing BaseModel.schema
|
||||
# Advanced settings
|
||||
batch_size: int = 20 # Reduced batch size to avoid large SQL statements
|
||||
enable_inverted_index: bool = True # Enable inverted index for full-text search
|
||||
analyzer_type: str = "chinese" # Analyzer type for full-text search: keyword, english, chinese, unicode
|
||||
analyzer_mode: str = "smart" # Analyzer mode: max_word, smart
|
||||
vector_distance_function: str = "cosine_distance" # l2_distance or cosine_distance
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""
|
||||
Validate the configuration values.
|
||||
"""
|
||||
if not values.get("username"):
|
||||
raise ValueError("config CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("config CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("config CLICKZETTA_INSTANCE is required")
|
||||
if not values.get("service"):
|
||||
raise ValueError("config CLICKZETTA_SERVICE is required")
|
||||
if not values.get("workspace"):
|
||||
raise ValueError("config CLICKZETTA_WORKSPACE is required")
|
||||
if not values.get("vcluster"):
|
||||
raise ValueError("config CLICKZETTA_VCLUSTER is required")
|
||||
if not values.get("schema_name"):
|
||||
raise ValueError("config CLICKZETTA_SCHEMA is required")
|
||||
return values
|
||||
|
||||
|
||||
class ClickzettaVector(BaseVector):
|
||||
"""
|
||||
Clickzetta vector storage implementation.
|
||||
"""
|
||||
|
||||
# Class-level write queue and lock for serializing writes
|
||||
_write_queue: Optional[queue.Queue] = None
|
||||
_write_thread: Optional[threading.Thread] = None
|
||||
_write_lock = threading.Lock()
|
||||
_shutdown = False
|
||||
|
||||
def __init__(self, collection_name: str, config: ClickzettaConfig):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
|
||||
self._connection: Optional["Connection"] = None
|
||||
self._init_connection()
|
||||
self._init_write_queue()
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize Clickzetta connection."""
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name
|
||||
)
|
||||
|
||||
# Set session parameters for better string handling and performance optimization
|
||||
if self._connection is not None:
|
||||
with self._connection.cursor() as cursor:
|
||||
# Use quote mode for string literal escaping to handle quotes better
|
||||
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'")
|
||||
logger.info("Set string literal escape mode to 'quote' for better quote handling")
|
||||
|
||||
# Performance optimization hints for vector operations
|
||||
self._set_performance_hints(cursor)
|
||||
|
||||
def _set_performance_hints(self, cursor):
|
||||
"""Set ClickZetta performance optimization hints for vector operations."""
|
||||
try:
|
||||
# Performance optimization hints for vector operations and query processing
|
||||
performance_hints = [
|
||||
# Vector index optimization
|
||||
"SET cz.storage.parquet.vector.index.read.memory.cache = true",
|
||||
"SET cz.storage.parquet.vector.index.read.local.cache = false",
|
||||
|
||||
# Query optimization
|
||||
"SET cz.sql.table.scan.push.down.filter = true",
|
||||
"SET cz.sql.table.scan.enable.ensure.filter = true",
|
||||
"SET cz.storage.always.prefetch.internal = true",
|
||||
"SET cz.optimizer.generate.columns.always.valid = true",
|
||||
"SET cz.sql.index.prewhere.enabled = true",
|
||||
|
||||
# Storage optimization
|
||||
"SET cz.storage.parquet.enable.io.prefetch = false",
|
||||
"SET cz.optimizer.enable.mv.rewrite = false",
|
||||
"SET cz.sql.dump.as.lz4 = true",
|
||||
"SET cz.optimizer.limited.optimization.naive.query = true",
|
||||
"SET cz.sql.table.scan.enable.push.down.log = false",
|
||||
"SET cz.storage.use.file.format.local.stats = false",
|
||||
"SET cz.storage.local.file.object.cache.level = all",
|
||||
|
||||
# Job execution optimization
|
||||
"SET cz.sql.job.fast.mode = true",
|
||||
"SET cz.storage.parquet.non.contiguous.read = true",
|
||||
"SET cz.sql.compaction.after.commit = true"
|
||||
]
|
||||
|
||||
for hint in performance_hints:
|
||||
cursor.execute(hint)
|
||||
|
||||
logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints))
|
||||
|
||||
except Exception:
|
||||
# Catch any errors setting performance hints but continue with defaults
|
||||
logger.exception("Failed to set some performance hints, continuing with default settings")
|
||||
|
||||
@classmethod
|
||||
def _init_write_queue(cls):
|
||||
"""Initialize the write queue and worker thread."""
|
||||
with cls._write_lock:
|
||||
if cls._write_queue is None:
|
||||
cls._write_queue = queue.Queue()
|
||||
cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True)
|
||||
cls._write_thread.start()
|
||||
logger.info("Started Clickzetta write worker thread")
|
||||
|
||||
@classmethod
|
||||
def _write_worker(cls):
|
||||
"""Worker thread that processes write tasks sequentially."""
|
||||
while not cls._shutdown:
|
||||
try:
|
||||
# Get task from queue with timeout
|
||||
if cls._write_queue is not None:
|
||||
task = cls._write_queue.get(timeout=1)
|
||||
if task is None: # Shutdown signal
|
||||
break
|
||||
|
||||
# Execute the write task
|
||||
func, args, kwargs, result_queue = task
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
result_queue.put((True, result))
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
logger.exception("Write task failed")
|
||||
result_queue.put((False, e))
|
||||
finally:
|
||||
cls._write_queue.task_done()
|
||||
else:
|
||||
break
|
||||
except queue.Empty:
|
||||
continue
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
logger.exception("Write worker error")
|
||||
|
||||
def _execute_write(self, func, *args, **kwargs):
|
||||
"""Execute a write operation through the queue."""
|
||||
if ClickzettaVector._write_queue is None:
|
||||
raise RuntimeError("Write queue not initialized")
|
||||
|
||||
result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue()
|
||||
ClickzettaVector._write_queue.put((func, args, kwargs, result_queue))
|
||||
|
||||
# Wait for result
|
||||
success, result = result_queue.get()
|
||||
if not success:
|
||||
raise result
|
||||
return result
|
||||
|
||||
def get_type(self) -> str:
|
||||
"""Return the vector database type."""
|
||||
return "clickzetta"
|
||||
|
||||
def _ensure_connection(self) -> "Connection":
|
||||
"""Ensure connection is available and return it."""
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Database connection not initialized")
|
||||
return self._connection
|
||||
|
||||
def _table_exists(self) -> bool:
|
||||
"""Check if the table exists."""
|
||||
try:
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}")
|
||||
return True
|
||||
except (RuntimeError, ValueError) as e:
|
||||
if "table or view not found" in str(e).lower():
|
||||
return False
|
||||
else:
|
||||
# Re-raise if it's a different error
|
||||
raise
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Create the collection and add initial documents."""
|
||||
# Execute table creation through write queue to avoid concurrent conflicts
|
||||
self._execute_write(self._create_table_and_indexes, embeddings)
|
||||
|
||||
# Add initial texts
|
||||
if texts:
|
||||
self.add_texts(texts, embeddings, **kwargs)
|
||||
|
||||
def _create_table_and_indexes(self, embeddings: list[list[float]]):
|
||||
"""Create table and indexes (executed in write worker thread)."""
|
||||
# Check if table already exists to avoid unnecessary index creation
|
||||
if self._table_exists():
|
||||
logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name)
|
||||
return
|
||||
|
||||
# Create table with vector and metadata columns
|
||||
dimension = len(embeddings[0]) if embeddings else 768
|
||||
|
||||
create_table_sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} (
|
||||
id STRING NOT NULL COMMENT 'Unique document identifier',
|
||||
{Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval',
|
||||
{Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes',
|
||||
{Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT
|
||||
'High-dimensional embedding vector for semantic similarity search',
|
||||
PRIMARY KEY (id)
|
||||
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content'
|
||||
"""
|
||||
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(create_table_sql)
|
||||
logger.info("Created table %s.%s", self._config.schema_name, self._table_name)
|
||||
|
||||
# Create vector index
|
||||
self._create_vector_index(cursor)
|
||||
|
||||
# Create inverted index for full-text search if enabled
|
||||
if self._config.enable_inverted_index:
|
||||
self._create_inverted_index(cursor)
|
||||
|
||||
def _create_vector_index(self, cursor):
|
||||
"""Create HNSW vector index for similarity search."""
|
||||
# Use a fixed index name based on table and column name
|
||||
index_name = f"idx_{self._table_name}_vector"
|
||||
|
||||
# First check if an index already exists on this column
|
||||
try:
|
||||
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
|
||||
existing_indexes = cursor.fetchall()
|
||||
for idx in existing_indexes:
|
||||
# Check if vector index already exists on the embedding column
|
||||
if Field.VECTOR.value in str(idx).lower():
|
||||
logger.info("Vector index already exists on column %s", Field.VECTOR.value)
|
||||
return
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logger.warning("Failed to check existing indexes: %s", e)
|
||||
|
||||
index_sql = f"""
|
||||
CREATE VECTOR INDEX IF NOT EXISTS {index_name}
|
||||
ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value})
|
||||
PROPERTIES (
|
||||
"distance.function" = "{self._config.vector_distance_function}",
|
||||
"scalar.type" = "f32",
|
||||
"m" = "16",
|
||||
"ef.construction" = "128"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
cursor.execute(index_sql)
|
||||
logger.info("Created vector index: %s", index_name)
|
||||
except (RuntimeError, ValueError) as e:
|
||||
error_msg = str(e).lower()
|
||||
if ("already exists" in error_msg or
|
||||
"already has index" in error_msg or
|
||||
"with the same type" in error_msg):
|
||||
logger.info("Vector index already exists: %s", e)
|
||||
else:
|
||||
logger.exception("Failed to create vector index")
|
||||
raise
|
||||
|
||||
def _create_inverted_index(self, cursor):
|
||||
"""Create inverted index for full-text search."""
|
||||
# Use a fixed index name based on table name to avoid duplicates
|
||||
index_name = f"idx_{self._table_name}_text"
|
||||
|
||||
# Check if an inverted index already exists on this column
|
||||
try:
|
||||
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
|
||||
existing_indexes = cursor.fetchall()
|
||||
for idx in existing_indexes:
|
||||
idx_str = str(idx).lower()
|
||||
# More precise check: look for inverted index specifically on the content column
|
||||
if ("inverted" in idx_str and
|
||||
Field.CONTENT_KEY.value.lower() in idx_str and
|
||||
(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
|
||||
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
|
||||
return
|
||||
except (RuntimeError, ValueError) as e:
|
||||
logger.warning("Failed to check existing indexes: %s", e)
|
||||
|
||||
index_sql = f"""
|
||||
CREATE INVERTED INDEX IF NOT EXISTS {index_name}
|
||||
ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value})
|
||||
PROPERTIES (
|
||||
"analyzer" = "{self._config.analyzer_type}",
|
||||
"mode" = "{self._config.analyzer_mode}"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
cursor.execute(index_sql)
|
||||
logger.info("Created inverted index: %s", index_name)
|
||||
except (RuntimeError, ValueError) as e:
|
||||
error_msg = str(e).lower()
|
||||
# Handle ClickZetta specific error messages
|
||||
if (("already exists" in error_msg or
|
||||
"already has index" in error_msg or
|
||||
"with the same type" in error_msg or
|
||||
"cannot create inverted index" in error_msg) and
|
||||
"already has index" in error_msg):
|
||||
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
|
||||
# Try to get the existing index name for logging
|
||||
try:
|
||||
cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}")
|
||||
existing_indexes = cursor.fetchall()
|
||||
for idx in existing_indexes:
|
||||
if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower():
|
||||
logger.info("Found existing inverted index: %s", idx)
|
||||
break
|
||||
except (RuntimeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
logger.warning("Failed to create inverted index: %s", e)
|
||||
# Continue without inverted index - full-text search will fall back to LIKE
|
||||
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
"""Add documents with embeddings to the collection."""
|
||||
if not documents:
|
||||
return
|
||||
|
||||
batch_size = self._config.batch_size
|
||||
total_batches = (len(documents) + batch_size - 1) // batch_size
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i:i + batch_size]
|
||||
batch_embeddings = embeddings[i:i + batch_size]
|
||||
|
||||
# Execute batch insert through write queue
|
||||
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
|
||||
|
||||
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
|
||||
batch_index: int, batch_size: int, total_batches: int):
|
||||
"""Insert a batch of documents using parameterized queries (executed in write worker thread)."""
|
||||
if not batch_docs or not batch_embeddings:
|
||||
logger.warning("Empty batch provided, skipping insertion")
|
||||
return
|
||||
|
||||
if len(batch_docs) != len(batch_embeddings):
|
||||
logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings))
|
||||
return
|
||||
|
||||
# Prepare data for parameterized insertion
|
||||
data_rows = []
|
||||
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
|
||||
|
||||
for doc, embedding in zip(batch_docs, batch_embeddings):
|
||||
# Optimized: minimal checks for common case, fallback for edge cases
|
||||
metadata = doc.metadata if doc.metadata else {}
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
|
||||
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
|
||||
|
||||
# Fast path for JSON serialization
|
||||
try:
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=True)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("JSON serialization failed, using empty dict")
|
||||
metadata_json = "{}"
|
||||
|
||||
content = doc.page_content or ""
|
||||
|
||||
# According to ClickZetta docs, vector should be formatted as array string
|
||||
# for external systems: '[1.0, 2.0, 3.0]'
|
||||
vector_str = '[' + ','.join(map(str, embedding)) + ']'
|
||||
data_rows.append([doc_id, content, metadata_json, vector_str])
|
||||
|
||||
# Check if we have any valid data to insert
|
||||
if not data_rows:
|
||||
logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches)
|
||||
return
|
||||
|
||||
# Use parameterized INSERT with executemany for better performance and security
|
||||
# Cast JSON and VECTOR in SQL, pass raw data as parameters
|
||||
columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}"
|
||||
insert_sql = (
|
||||
f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) "
|
||||
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))"
|
||||
)
|
||||
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
# Set session-level hints for batch insert operations
|
||||
# Note: executemany doesn't support hints parameter, so we set them as session variables
|
||||
cursor.execute("SET cz.sql.job.fast.mode = true")
|
||||
cursor.execute("SET cz.sql.compaction.after.commit = true")
|
||||
cursor.execute("SET cz.storage.always.prefetch.internal = true")
|
||||
|
||||
cursor.executemany(insert_sql, data_rows)
|
||||
logger.info(
|
||||
f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
|
||||
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
|
||||
)
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e)
|
||||
logger.exception("SQL template: %s", insert_sql)
|
||||
logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None')
|
||||
raise
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
"""Check if a document exists by ID."""
|
||||
safe_id = self._safe_doc_id(id)
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
|
||||
[safe_id]
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
return result[0] > 0 if result else False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
"""Delete documents by IDs."""
|
||||
if not ids:
|
||||
return
|
||||
|
||||
# Check if table exists before attempting delete
|
||||
if not self._table_exists():
|
||||
logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
|
||||
return
|
||||
|
||||
# Execute delete through write queue
|
||||
self._execute_write(self._delete_by_ids_impl, ids)
|
||||
|
||||
def _delete_by_ids_impl(self, ids: list[str]) -> None:
|
||||
"""Implementation of delete by IDs (executed in write worker thread)."""
|
||||
safe_ids = [self._safe_doc_id(id) for id in ids]
|
||||
# Create properly escaped string literals for SQL
|
||||
id_list = ",".join(f"'{id}'" for id in safe_ids)
|
||||
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({id_list})"
|
||||
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
"""Delete documents by metadata field."""
|
||||
# Check if table exists before attempting delete
|
||||
if not self._table_exists():
|
||||
logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name)
|
||||
return
|
||||
|
||||
# Execute delete through write queue
|
||||
self._execute_write(self._delete_by_metadata_field_impl, key, value)
|
||||
|
||||
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None:
|
||||
"""Implementation of delete by metadata field (executed in write worker thread)."""
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
# Using JSON path to filter with parameterized query
|
||||
# Note: JSON path requires literal key name, cannot be parameterized
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} "
|
||||
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?")
|
||||
cursor.execute(sql, [value])
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents by vector similarity."""
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
score_threshold = kwargs.get("score_threshold", 0.0)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
if document_ids_filter:
|
||||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Add distance threshold based on distance function
|
||||
vector_dimension = len(query_vector)
|
||||
if self._config.vector_distance_function == "cosine_distance":
|
||||
# For cosine distance, smaller is better (0 = identical, 2 = opposite)
|
||||
distance_func = "COSINE_DISTANCE"
|
||||
if score_threshold > 0:
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
|
||||
f"{query_vector_str}) < {2 - score_threshold}")
|
||||
else:
|
||||
# For L2 distance, smaller is better
|
||||
distance_func = "L2_DISTANCE"
|
||||
if score_threshold > 0:
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
|
||||
f"{query_vector_str}) < {score_threshold}")
|
||||
|
||||
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"
|
||||
|
||||
# Execute vector search query
|
||||
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value},
|
||||
{distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
ORDER BY distance
|
||||
LIMIT {top_k}
|
||||
"""
|
||||
|
||||
documents = []
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
# Use hints parameter for vector search optimization
|
||||
search_hints = {
|
||||
'hints': {
|
||||
'sdk.job.timeout': 60, # Increase timeout for vector search
|
||||
'cz.sql.job.fast.mode': True,
|
||||
'cz.storage.parquet.vector.index.read.memory.cache': True
|
||||
}
|
||||
}
|
||||
cursor.execute(search_sql, parameters=search_hints)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.error("JSON parsing failed: %s", e)
|
||||
# Fallback: extract document_id with regex
|
||||
import re
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
|
||||
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
|
||||
|
||||
# Ensure required fields are set
|
||||
metadata["doc_id"] = row[0] # segment id
|
||||
|
||||
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
|
||||
if "document_id" not in metadata:
|
||||
metadata["document_id"] = row[0] # fallback to segment id
|
||||
|
||||
# Add score based on distance
|
||||
if self._config.vector_distance_function == "cosine_distance":
|
||||
metadata["score"] = 1 - (row[3] / 2)
|
||||
else:
|
||||
metadata["score"] = 1 / (1 + row[3])
|
||||
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Search for documents using full-text search with inverted index."""
|
||||
if not self._config.enable_inverted_index:
|
||||
logger.warning("Full-text search is not enabled. Enable inverted index in config.")
|
||||
return []
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
if document_ids_filter:
|
||||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use match_all function for full-text search
|
||||
# match_all requires all terms to be present
|
||||
# Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')")
|
||||
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
# Execute full-text search query
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
"""
|
||||
|
||||
documents = []
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
# Use hints parameter for full-text search optimization
|
||||
fulltext_hints = {
|
||||
'hints': {
|
||||
'sdk.job.timeout': 30, # Timeout for full-text search
|
||||
'cz.sql.job.fast.mode': True,
|
||||
'cz.sql.index.prewhere.enabled': True
|
||||
}
|
||||
}
|
||||
cursor.execute(search_sql, parameters=fulltext_hints)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.error("JSON parsing failed: %s", e)
|
||||
# Fallback: extract document_id with regex
|
||||
import re
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
|
||||
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
|
||||
|
||||
# Ensure required fields are set
|
||||
metadata["doc_id"] = row[0] # segment id
|
||||
|
||||
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
|
||||
if "document_id" not in metadata:
|
||||
metadata["document_id"] = row[0] # fallback to segment id
|
||||
|
||||
# Add a relevance score for full-text search
|
||||
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
|
||||
logger.exception("Full-text search failed")
|
||||
# Fallback to LIKE search if full-text search fails
|
||||
return self._search_by_like(query, **kwargs)
|
||||
|
||||
return documents
|
||||
|
||||
def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
"""Fallback search using LIKE operator."""
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
|
||||
# Handle filter parameter from canvas (workflow)
|
||||
filter_param = kwargs.get("filter", {})
|
||||
|
||||
# Build filter clause
|
||||
filter_clauses = []
|
||||
if document_ids_filter:
|
||||
safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter]
|
||||
doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids)
|
||||
# Use json_extract_string function for ClickZetta compatibility
|
||||
filter_clauses.append(
|
||||
f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})"
|
||||
)
|
||||
|
||||
# No need for dataset_id filter since each dataset has its own table
|
||||
|
||||
# Use simple quote escaping for LIKE clause
|
||||
escaped_query = query.replace("'", "''")
|
||||
filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'")
|
||||
where_clause = " AND ".join(filter_clauses)
|
||||
|
||||
search_sql = f"""
|
||||
SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}
|
||||
FROM {self._config.schema_name}.{self._table_name}
|
||||
WHERE {where_clause}
|
||||
LIMIT {top_k}
|
||||
"""
|
||||
|
||||
documents = []
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
# Use hints parameter for LIKE search optimization
|
||||
like_hints = {
|
||||
'hints': {
|
||||
'sdk.job.timeout': 20, # Timeout for LIKE search
|
||||
'cz.sql.job.fast.mode': True
|
||||
}
|
||||
}
|
||||
cursor.execute(search_sql, parameters=like_hints)
|
||||
results = cursor.fetchall()
|
||||
|
||||
for row in results:
|
||||
# Parse metadata from JSON string (may be double-encoded)
|
||||
try:
|
||||
if row[2]:
|
||||
metadata = json.loads(row[2])
|
||||
|
||||
# If result is a string, it's double-encoded JSON - parse again
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
else:
|
||||
metadata = {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.error("JSON parsing failed: %s", e)
|
||||
# Fallback: extract document_id with regex
|
||||
import re
|
||||
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
|
||||
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}
|
||||
|
||||
# Ensure required fields are set
|
||||
metadata["doc_id"] = row[0] # segment id
|
||||
|
||||
# Ensure document_id exists (critical for Dify's format_retrieval_documents)
|
||||
if "document_id" not in metadata:
|
||||
metadata["document_id"] = row[0] # fallback to segment id
|
||||
|
||||
metadata["score"] = 0.5 # Lower score for LIKE search
|
||||
doc = Document(page_content=row[1], metadata=metadata)
|
||||
documents.append(doc)
|
||||
|
||||
return documents
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the entire collection."""
|
||||
connection = self._ensure_connection()
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
|
||||
|
||||
|
||||
def _format_vector_simple(self, vector: list[float]) -> str:
|
||||
"""Simple vector formatting for SQL queries."""
|
||||
return ','.join(map(str, vector))
|
||||
|
||||
def _safe_doc_id(self, doc_id: str) -> str:
|
||||
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
|
||||
if not doc_id:
|
||||
return str(uuid.uuid4())
|
||||
# Remove or replace potentially problematic characters
|
||||
safe_id = str(doc_id)
|
||||
# Only allow alphanumeric, hyphens, underscores
|
||||
safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
|
||||
if not safe_id: # If all characters were removed
|
||||
return str(uuid.uuid4())
|
||||
return safe_id[:255] # Limit length
|
||||
|
||||
|
||||
|
||||
class ClickzettaVectorFactory(AbstractVectorFactory):
|
||||
"""Factory for creating Clickzetta vector instances."""
|
||||
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
"""Initialize a Clickzetta vector instance."""
|
||||
# Get configuration from environment variables or dataset config
|
||||
config = ClickzettaConfig(
|
||||
username=dify_config.CLICKZETTA_USERNAME or "",
|
||||
password=dify_config.CLICKZETTA_PASSWORD or "",
|
||||
instance=dify_config.CLICKZETTA_INSTANCE or "",
|
||||
service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com",
|
||||
workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start",
|
||||
vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap",
|
||||
schema_name=dify_config.CLICKZETTA_SCHEMA or "dify",
|
||||
batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100,
|
||||
enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True,
|
||||
analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese",
|
||||
analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart",
|
||||
vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance",
|
||||
)
|
||||
|
||||
# Use dataset collection name as table name
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
|
||||
|
||||
return ClickzettaVector(collection_name=collection_name, config=config)
|
||||
|
||||
|
|
@ -172,6 +172,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
|
||||
|
||||
return MatrixoneVectorFactory
|
||||
case VectorType.CLICKZETTA:
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
|
|
|||
|
|
@ -30,3 +30,4 @@ class VectorType(StrEnum):
|
|||
TABLESTORE = "tablestore"
|
||||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
|
|
|
|||
|
|
@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
|||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
try:
|
||||
if local_tz is None:
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
if isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
if local_tz is None:
|
||||
localtime = local_time.astimezone() # type: ignore
|
||||
elif isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from os import getenv
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
|
@ -20,6 +21,20 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedResponse:
|
||||
"""Represents a parsed HTTP response with type information"""
|
||||
|
||||
content: Union[str, dict]
|
||||
is_json: bool
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert response to string format for credential validation"""
|
||||
if isinstance(self.content, dict):
|
||||
return json.dumps(self.content, ensure_ascii=False)
|
||||
return str(self.content)
|
||||
|
||||
|
||||
class ApiTool(Tool):
|
||||
"""
|
||||
Api tool
|
||||
|
|
@ -58,7 +73,9 @@ class ApiTool(Tool):
|
|||
|
||||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
|
||||
# validate response
|
||||
return self.validate_and_parse_response(response)
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
# For credential validation, always return as string
|
||||
return parsed_response.to_string()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
|
@ -112,23 +129,36 @@ class ApiTool(Tool):
|
|||
|
||||
return headers
|
||||
|
||||
def validate_and_parse_response(self, response: httpx.Response) -> str:
|
||||
def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse:
|
||||
"""
|
||||
validate the response
|
||||
validate the response and return parsed content with type information
|
||||
|
||||
:return: ParsedResponse with content and is_json flag
|
||||
"""
|
||||
if isinstance(response, httpx.Response):
|
||||
if response.status_code >= 400:
|
||||
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
|
||||
if not response.content:
|
||||
return "Empty response from the tool, please check your parameters and try again."
|
||||
return ParsedResponse(
|
||||
"Empty response from the tool, please check your parameters and try again.", False
|
||||
)
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
is_json_content_type = "application/json" in content_type
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
response = response.json()
|
||||
try:
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
except Exception:
|
||||
return json.dumps(response)
|
||||
json_response = response.json()
|
||||
# If content-type indicates JSON, return as JSON object
|
||||
if is_json_content_type:
|
||||
return ParsedResponse(json_response, True)
|
||||
else:
|
||||
# If content-type doesn't indicate JSON, treat as text regardless of content
|
||||
return ParsedResponse(response.text, False)
|
||||
except Exception:
|
||||
return response.text
|
||||
# Not valid JSON, return as text
|
||||
return ParsedResponse(response.text, False)
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
|
||||
|
|
@ -369,7 +399,14 @@ class ApiTool(Tool):
|
|||
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
|
||||
|
||||
# validate response
|
||||
response = self.validate_and_parse_response(response)
|
||||
parsed_response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message
|
||||
yield self.create_text_message(response)
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json and isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
text_response = (
|
||||
parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content)
|
||||
)
|
||||
yield self.create_text_message(text_response)
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class Executor:
|
|||
self.auth = node_data.authorization
|
||||
self.timeout = timeout
|
||||
self.ssl_verify = node_data.ssl_verify
|
||||
self.params = []
|
||||
self.params = None
|
||||
self.headers = {}
|
||||
self.content = None
|
||||
self.files = None
|
||||
|
|
@ -139,7 +139,8 @@ class Executor:
|
|||
(self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text)
|
||||
)
|
||||
|
||||
self.params = result
|
||||
if result:
|
||||
self.params = result
|
||||
|
||||
def _init_headers(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -69,6 +69,19 @@ class Storage:
|
|||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
case StorageType.CLICKZETTA_VOLUME:
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
def create_clickzetta_volume_storage():
|
||||
# ClickZettaVolumeConfig will automatically read from environment variables
|
||||
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
|
||||
volume_config = ClickZettaVolumeConfig()
|
||||
return ClickZettaVolumeStorage(volume_config)
|
||||
|
||||
return create_clickzetta_volume_storage
|
||||
case _:
|
||||
raise ValueError(f"unsupported storage type {storage_type}")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
|
|
@ -0,0 +1,530 @@
|
|||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import clickzetta # type: ignore[import]
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: Optional[str] = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
import os
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# 暂时禁用权限检查功能,直接设置为false
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str) -> None:
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
"""Save data to ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
data: File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Ensure table volume exists (for table volumes)
|
||||
if dataset_id:
|
||||
self._ensure_table_volume_exists(dataset_id)
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "save", dataset_id)
|
||||
|
||||
# Write data to temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_file.write(data)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
try:
|
||||
# Upload to volume
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
Path(temp_file_path).unlink(missing_ok=True)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
"""Load file content from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
# Check permissions (if enabled)
|
||||
if self._config.permission_check:
|
||||
# Skip permission check for special directories that use USER VOLUME
|
||||
if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
if self._permission_manager is not None:
|
||||
check_volume_permission(self._permission_manager, "load_once", dataset_id)
|
||||
|
||||
# Download to temporary directory
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
|
||||
else:
|
||||
sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
# Find the downloaded file (may be in subdirectories)
|
||||
downloaded_file = None
|
||||
for root, dirs, files in os.walk(temp_dir):
|
||||
for file in files:
|
||||
if file == filename or file == os.path.basename(filename):
|
||||
downloaded_file = Path(root) / file
|
||||
break
|
||||
if downloaded_file:
|
||||
break
|
||||
|
||||
if not downloaded_file or not downloaded_file.exists():
|
||||
raise FileNotFoundError(f"Downloaded file not found: {filename}")
|
||||
|
||||
content = downloaded_file.read_bytes()
|
||||
|
||||
logger.debug("File %s loaded from ClickZetta Volume", filename)
|
||||
return content
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
"""Load file as stream from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Yields:
|
||||
File content chunks
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
batch_size = 4096
|
||||
stream = BytesIO(content)
|
||||
|
||||
while chunk := stream.read(batch_size):
|
||||
yield chunk
|
||||
|
||||
logger.debug("File %s loaded as stream from ClickZetta Volume", filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
"""Download file from ClickZetta Volume to local path.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
target_filepath: Local target file path
|
||||
"""
|
||||
content = self.load_once(filename)
|
||||
|
||||
with Path(target_filepath).open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
"""Check if file exists in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
|
||||
Returns:
|
||||
True if file exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
exists = len(rows) > 0
|
||||
logger.debug("File %s exists check: %s", filename, exists)
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.warning("Error checking file existence for %s: %s", filename, e)
|
||||
return False
|
||||
|
||||
def delete(self, filename: str):
|
||||
"""Delete file from ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
filename: File path in volume
|
||||
"""
|
||||
if not self.exists(filename):
|
||||
logger.debug("File %s not found, skip delete", filename)
|
||||
return
|
||||
|
||||
# Extract dataset_id from filename if present
|
||||
dataset_id = None
|
||||
if "/" in filename and self._config.volume_type == "table":
|
||||
parts = filename.split("/", 1)
|
||||
if parts[0].startswith(self._config.table_prefix):
|
||||
dataset_id = parts[0][len(self._config.table_prefix) :]
|
||||
filename = parts[1]
|
||||
else:
|
||||
dataset_id = parts[0]
|
||||
filename = parts[1]
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# Get the actual volume path (may include dify_km prefix)
|
||||
volume_path = self._get_volume_path(filename, dataset_id)
|
||||
|
||||
# For User Volume, use the full path with dify_km prefix
|
||||
if volume_prefix == "USER VOLUME":
|
||||
sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
|
||||
else:
|
||||
sql = f"REMOVE {volume_prefix} FILE '{filename}'"
|
||||
|
||||
self._execute_sql(sql)
|
||||
|
||||
logger.debug("File %s deleted from ClickZetta Volume", filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""Scan files and directories in ClickZetta Volume.
|
||||
|
||||
Args:
|
||||
path: Path to scan (dataset_id for table volumes)
|
||||
files: Include files in results
|
||||
directories: Include directories in results
|
||||
|
||||
Returns:
|
||||
List of file/directory paths
|
||||
"""
|
||||
try:
|
||||
# For table volumes, path is treated as dataset_id
|
||||
dataset_id = None
|
||||
if self._config.volume_type == "table":
|
||||
dataset_id = path
|
||||
path = "" # Root of the table volume
|
||||
|
||||
volume_prefix = self._get_volume_sql_prefix(dataset_id)
|
||||
|
||||
# For User Volume, add dify prefix to path
|
||||
if volume_prefix == "USER VOLUME":
|
||||
if path:
|
||||
scan_path = f"{self._config.dify_prefix}/{path}"
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
|
||||
else:
|
||||
if path:
|
||||
sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
|
||||
else:
|
||||
sql = f"LIST {volume_prefix}"
|
||||
|
||||
rows = self._execute_sql(sql, fetch=True)
|
||||
|
||||
result = []
|
||||
for row in rows:
|
||||
file_path = row[0] # relative_path column
|
||||
|
||||
# For User Volume, remove dify prefix from results
|
||||
dify_prefix_with_slash = f"{self._config.dify_prefix}/"
|
||||
if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
|
||||
file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix
|
||||
|
||||
if files and not file_path.endswith("/") or directories and file_path.endswith("/"):
|
||||
result.append(file_path)
|
||||
|
||||
logger.debug("Scanned %d items in path %s", len(result), path)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error scanning path %s", path)
|
||||
return []
|
||||
|
|
@ -0,0 +1,516 @@
|
|||
"""ClickZetta Volume文件生命周期管理
|
||||
|
||||
该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。
|
||||
支持知识库文件的完整生命周期管理。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStatus(Enum):
|
||||
"""文件状态枚举"""
|
||||
|
||||
ACTIVE = "active" # 活跃状态
|
||||
ARCHIVED = "archived" # 已归档
|
||||
DELETED = "deleted" # 已删除(软删除)
|
||||
BACKUP = "backup" # 备份文件
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileMetadata:
|
||||
"""文件元数据"""
|
||||
|
||||
filename: str
|
||||
size: int | None
|
||||
created_at: datetime
|
||||
modified_at: datetime
|
||||
version: int | None
|
||||
status: FileStatus
|
||||
checksum: Optional[str] = None
|
||||
tags: Optional[dict[str, str]] = None
|
||||
parent_version: Optional[int] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
data = asdict(self)
|
||||
data["created_at"] = self.created_at.isoformat()
|
||||
data["modified_at"] = self.modified_at.isoformat()
|
||||
data["status"] = self.status.value
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FileMetadata":
|
||||
"""从字典创建实例"""
|
||||
data = data.copy()
|
||||
data["created_at"] = datetime.fromisoformat(data["created_at"])
|
||||
data["modified_at"] = datetime.fromisoformat(data["modified_at"])
|
||||
data["status"] = FileStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class FileLifecycleManager:
|
||||
"""文件生命周期管理器"""
|
||||
|
||||
def __init__(self, storage, dataset_id: Optional[str] = None):
|
||||
"""初始化生命周期管理器
|
||||
|
||||
Args:
|
||||
storage: ClickZetta Volume存储实例
|
||||
dataset_id: 数据集ID(用于Table Volume)
|
||||
"""
|
||||
self._storage = storage
|
||||
self._dataset_id = dataset_id
|
||||
self._metadata_file = ".dify_file_metadata.json"
|
||||
self._version_prefix = ".versions/"
|
||||
self._backup_prefix = ".backups/"
|
||||
self._deleted_prefix = ".deleted/"
|
||||
|
||||
# 获取权限管理器(如果存在)
|
||||
self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None)
|
||||
|
||||
def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata:
|
||||
"""保存文件并管理生命周期
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
data: 文件内容
|
||||
tags: 文件标签
|
||||
|
||||
Returns:
|
||||
文件元数据
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "save"):
|
||||
from .volume_permissions import VolumePermissionError
|
||||
|
||||
raise VolumePermissionError(
|
||||
f"Permission denied for lifecycle save operation on file: {filename}",
|
||||
operation="save",
|
||||
volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
|
||||
dataset_id=self._dataset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. 检查是否存在旧版本
|
||||
metadata_dict = self._load_metadata()
|
||||
current_metadata = metadata_dict.get(filename)
|
||||
|
||||
# 2. 如果存在旧版本,创建版本备份
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata)
|
||||
|
||||
# 3. 计算文件信息
|
||||
now = datetime.now()
|
||||
checksum = self._calculate_checksum(data)
|
||||
new_version = (current_metadata["version"] + 1) if current_metadata else 1
|
||||
|
||||
# 4. 保存新文件
|
||||
self._storage.save(filename, data)
|
||||
|
||||
# 5. 创建元数据
|
||||
created_at = now
|
||||
parent_version = None
|
||||
|
||||
if current_metadata:
|
||||
# 如果created_at是字符串,转换为datetime
|
||||
if isinstance(current_metadata["created_at"], str):
|
||||
created_at = datetime.fromisoformat(current_metadata["created_at"])
|
||||
else:
|
||||
created_at = current_metadata["created_at"]
|
||||
parent_version = current_metadata["version"]
|
||||
|
||||
file_metadata = FileMetadata(
|
||||
filename=filename,
|
||||
size=len(data),
|
||||
created_at=created_at,
|
||||
modified_at=now,
|
||||
version=new_version,
|
||||
status=FileStatus.ACTIVE,
|
||||
checksum=checksum,
|
||||
tags=tags or {},
|
||||
parent_version=parent_version,
|
||||
)
|
||||
|
||||
# 6. 更新元数据
|
||||
metadata_dict[filename] = file_metadata.to_dict()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
|
||||
return file_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save file with lifecycle")
|
||||
raise
|
||||
|
||||
def get_file_metadata(self, filename: str) -> Optional[FileMetadata]:
|
||||
"""获取文件元数据
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件元数据,如果不存在返回None
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
return FileMetadata.from_dict(metadata_dict[filename])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get file metadata for %s", filename)
|
||||
return None
|
||||
|
||||
def list_file_versions(self, filename: str) -> list[FileMetadata]:
|
||||
"""列出文件的所有版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
文件版本列表,按版本号排序
|
||||
"""
|
||||
try:
|
||||
versions = []
|
||||
|
||||
# 获取当前版本
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
versions.append(current_metadata)
|
||||
|
||||
# 获取历史版本
|
||||
version_pattern = f"{self._version_prefix}{filename}.v*"
|
||||
try:
|
||||
version_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
for file_path in version_files:
|
||||
if file_path.startswith(f"{self._version_prefix}{filename}.v"):
|
||||
# 解析版本号
|
||||
version_str = file_path.split(".v")[-1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_str)
|
||||
# 这里简化处理,实际应该从版本文件中读取元数据
|
||||
# 暂时创建基本的元数据信息
|
||||
except ValueError:
|
||||
continue
|
||||
except:
|
||||
# 如果无法扫描版本文件,只返回当前版本
|
||||
pass
|
||||
|
||||
return sorted(versions, key=lambda x: x.version or 0, reverse=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to list file versions for %s", filename)
|
||||
return []
|
||||
|
||||
def restore_version(self, filename: str, version: int) -> bool:
|
||||
"""恢复文件到指定版本
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
version: 要恢复的版本号
|
||||
|
||||
Returns:
|
||||
恢复是否成功
|
||||
"""
|
||||
try:
|
||||
version_filename = f"{self._version_prefix}{filename}.v{version}"
|
||||
|
||||
# 检查版本文件是否存在
|
||||
if not self._storage.exists(version_filename):
|
||||
logger.warning("Version %s of %s not found", version, filename)
|
||||
return False
|
||||
|
||||
# 读取版本文件内容
|
||||
version_data = self._storage.load_once(version_filename)
|
||||
|
||||
# 保存当前版本为备份
|
||||
current_metadata = self.get_file_metadata(filename)
|
||||
if current_metadata:
|
||||
self._create_version_backup(filename, current_metadata.to_dict())
|
||||
|
||||
# 恢复文件
|
||||
self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to restore %s to version %s", filename, version)
|
||||
return False
|
||||
|
||||
def archive_file(self, filename: str) -> bool:
|
||||
"""归档文件
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
归档是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "archive"):
|
||||
logger.warning("Permission denied for archive operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 更新文件状态为归档
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename not in metadata_dict:
|
||||
logger.warning("File %s not found in metadata", filename)
|
||||
return False
|
||||
|
||||
metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s archived successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to archive file %s", filename)
|
||||
return False
|
||||
|
||||
def soft_delete_file(self, filename: str) -> bool:
|
||||
"""软删除文件(移动到删除目录)
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
# 权限检查
|
||||
if not self._check_permission(filename, "delete"):
|
||||
logger.warning("Permission denied for soft delete operation on file: %s", filename)
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查文件是否存在
|
||||
if not self._storage.exists(filename):
|
||||
logger.warning("File %s not found", filename)
|
||||
return False
|
||||
|
||||
# 读取文件内容
|
||||
file_data = self._storage.load_once(filename)
|
||||
|
||||
# 移动到删除目录
|
||||
deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
self._storage.save(deleted_filename, file_data)
|
||||
|
||||
# 删除原文件
|
||||
self._storage.delete(filename)
|
||||
|
||||
# 更新元数据
|
||||
metadata_dict = self._load_metadata()
|
||||
if filename in metadata_dict:
|
||||
metadata_dict[filename]["status"] = FileStatus.DELETED.value
|
||||
metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
|
||||
self._save_metadata(metadata_dict)
|
||||
|
||||
logger.info("File %s soft deleted successfully", filename)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to soft delete file %s", filename)
|
||||
return False
|
||||
|
||||
def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
|
||||
"""清理旧版本文件
|
||||
|
||||
Args:
|
||||
max_versions: 保留的最大版本数
|
||||
max_age_days: 版本文件的最大保留天数
|
||||
|
||||
Returns:
|
||||
清理的文件数量
|
||||
"""
|
||||
try:
|
||||
cleaned_count = 0
|
||||
cutoff_date = datetime.now() - timedelta(days=max_age_days)
|
||||
|
||||
# 获取所有版本文件
|
||||
try:
|
||||
all_files = self._storage.scan(self._dataset_id or "", files=True)
|
||||
version_files = [f for f in all_files if f.startswith(self._version_prefix)]
|
||||
|
||||
# 按文件分组
|
||||
file_versions: dict[str, list[tuple[int, str]]] = {}
|
||||
for version_file in version_files:
|
||||
# 解析文件名和版本
|
||||
parts = version_file[len(self._version_prefix) :].split(".v")
|
||||
if len(parts) >= 2:
|
||||
base_filename = parts[0]
|
||||
version_part = parts[1].split(".")[0]
|
||||
try:
|
||||
version_num = int(version_part)
|
||||
if base_filename not in file_versions:
|
||||
file_versions[base_filename] = []
|
||||
file_versions[base_filename].append((version_num, version_file))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# 清理每个文件的旧版本
|
||||
for base_filename, versions in file_versions.items():
|
||||
# 按版本号排序
|
||||
versions.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 保留最新的max_versions个版本,删除其余的
|
||||
if len(versions) > max_versions:
|
||||
to_delete = versions[max_versions:]
|
||||
for version_num, version_file in to_delete:
|
||||
self._storage.delete(version_file)
|
||||
cleaned_count += 1
|
||||
logger.debug("Cleaned old version: %s", version_file)
|
||||
|
||||
logger.info("Cleaned %d old version files", cleaned_count)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not scan for version files: %s", e)
|
||||
|
||||
return cleaned_count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to cleanup old versions")
|
||||
return 0
|
||||
|
||||
def get_storage_statistics(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息
|
||||
|
||||
Returns:
|
||||
存储统计字典
|
||||
"""
|
||||
try:
|
||||
metadata_dict = self._load_metadata()
|
||||
|
||||
stats: dict[str, Any] = {
|
||||
"total_files": len(metadata_dict),
|
||||
"active_files": 0,
|
||||
"archived_files": 0,
|
||||
"deleted_files": 0,
|
||||
"total_size": 0,
|
||||
"versions_count": 0,
|
||||
"oldest_file": None,
|
||||
"newest_file": None,
|
||||
}
|
||||
|
||||
oldest_date = None
|
||||
newest_date = None
|
||||
|
||||
for filename, metadata in metadata_dict.items():
|
||||
file_meta = FileMetadata.from_dict(metadata)
|
||||
|
||||
# 统计文件状态
|
||||
if file_meta.status == FileStatus.ACTIVE:
|
||||
stats["active_files"] = (stats["active_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.ARCHIVED:
|
||||
stats["archived_files"] = (stats["archived_files"] or 0) + 1
|
||||
elif file_meta.status == FileStatus.DELETED:
|
||||
stats["deleted_files"] = (stats["deleted_files"] or 0) + 1
|
||||
|
||||
# 统计大小
|
||||
stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)
|
||||
|
||||
# 统计版本
|
||||
stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)
|
||||
|
||||
# 找出最新和最旧的文件
|
||||
if oldest_date is None or file_meta.created_at < oldest_date:
|
||||
oldest_date = file_meta.created_at
|
||||
stats["oldest_file"] = filename
|
||||
|
||||
if newest_date is None or file_meta.modified_at > newest_date:
|
||||
newest_date = file_meta.modified_at
|
||||
stats["newest_file"] = filename
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get storage statistics")
|
||||
return {}
|
||||
|
||||
def _create_version_backup(self, filename: str, metadata: dict):
|
||||
"""创建版本备份"""
|
||||
try:
|
||||
# 读取当前文件内容
|
||||
current_data = self._storage.load_once(filename)
|
||||
|
||||
# 保存为版本文件
|
||||
version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
|
||||
self._storage.save(version_filename, current_data)
|
||||
|
||||
logger.debug("Created version backup: %s", version_filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create version backup for %s: %s", filename, e)
|
||||
|
||||
def _load_metadata(self) -> dict[str, Any]:
|
||||
"""加载元数据文件"""
|
||||
try:
|
||||
if self._storage.exists(self._metadata_file):
|
||||
metadata_content = self._storage.load_once(self._metadata_file)
|
||||
result = json.loads(metadata_content.decode("utf-8"))
|
||||
return dict(result) if result else {}
|
||||
else:
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load metadata: %s", e)
|
||||
return {}
|
||||
|
||||
def _save_metadata(self, metadata_dict: dict):
|
||||
"""保存元数据文件"""
|
||||
try:
|
||||
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
|
||||
self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
|
||||
logger.debug("Metadata saved successfully")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save metadata")
|
||||
raise
|
||||
|
||||
def _calculate_checksum(self, data: bytes) -> str:
|
||||
"""计算文件校验和"""
|
||||
import hashlib
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
def _check_permission(self, filename: str, operation: str) -> bool:
|
||||
"""检查文件操作权限
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
operation: 操作类型
|
||||
|
||||
Returns:
|
||||
True if permission granted, False otherwise
|
||||
"""
|
||||
# 如果没有权限管理器,默认允许
|
||||
if not self._permission_manager:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 根据操作类型映射到权限
|
||||
operation_mapping = {
|
||||
"save": "save",
|
||||
"load": "load_once",
|
||||
"delete": "delete",
|
||||
"archive": "delete", # 归档需要删除权限
|
||||
"restore": "save", # 恢复需要写权限
|
||||
"cleanup": "delete", # 清理需要删除权限
|
||||
"read": "load_once",
|
||||
"write": "save",
|
||||
}
|
||||
|
||||
mapped_operation = operation_mapping.get(operation, operation)
|
||||
|
||||
# 检查权限
|
||||
result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed for %s operation %s", filename, operation)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
return False
|
||||
|
|
@ -0,0 +1,646 @@
|
|||
"""ClickZetta Volume权限管理机制
|
||||
|
||||
该模块提供Volume权限检查、验证和管理功能。
|
||||
根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VolumePermission(Enum):
|
||||
"""Volume权限类型枚举"""
|
||||
|
||||
READ = "SELECT" # 对应ClickZetta的SELECT权限
|
||||
WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限
|
||||
LIST = "SELECT" # 列出文件需要SELECT权限
|
||||
DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限
|
||||
USAGE = "USAGE" # External Volume需要的基本权限
|
||||
|
||||
|
||||
class VolumePermissionManager:
|
||||
"""Volume权限管理器"""
|
||||
|
||||
def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None):
|
||||
"""初始化权限管理器
|
||||
|
||||
Args:
|
||||
connection_or_config: ClickZetta连接对象或配置字典
|
||||
volume_type: Volume类型 (user|table|external)
|
||||
volume_name: Volume名称 (用于external volume)
|
||||
"""
|
||||
# 支持两种初始化方式:连接对象或配置字典
|
||||
if isinstance(connection_or_config, dict):
|
||||
# 从配置字典创建连接
|
||||
import clickzetta # type: ignore[import-untyped]
|
||||
|
||||
config = connection_or_config
|
||||
self._connection = clickzetta.connect(
|
||||
username=config.get("username"),
|
||||
password=config.get("password"),
|
||||
instance=config.get("instance"),
|
||||
service=config.get("service"),
|
||||
workspace=config.get("workspace"),
|
||||
vcluster=config.get("vcluster"),
|
||||
schema=config.get("schema") or config.get("database"),
|
||||
)
|
||||
self._volume_type = config.get("volume_type", volume_type)
|
||||
self._volume_name = config.get("volume_name", volume_name)
|
||||
else:
|
||||
# 直接使用连接对象
|
||||
self._connection = connection_or_config
|
||||
self._volume_type = volume_type
|
||||
self._volume_name = volume_name
|
||||
|
||||
if not self._connection:
|
||||
raise ValueError("Valid connection or config is required")
|
||||
if not self._volume_type:
|
||||
raise ValueError("volume_type is required")
|
||||
|
||||
self._permission_cache: dict[str, set[str]] = {}
|
||||
self._current_username = None # 将从连接中获取当前用户名
|
||||
|
||||
def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool:
|
||||
"""检查用户是否有执行特定操作的权限
|
||||
|
||||
Args:
|
||||
operation: 要执行的操作类型
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
if self._volume_type == "user":
|
||||
return self._check_user_volume_permission(operation)
|
||||
elif self._volume_type == "table":
|
||||
return self._check_table_volume_permission(operation, dataset_id)
|
||||
elif self._volume_type == "external":
|
||||
return self._check_external_volume_permission(operation)
|
||||
else:
|
||||
logger.warning("Unknown volume type: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission check failed")
|
||||
return False
|
||||
|
||||
def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查User Volume权限
|
||||
|
||||
User Volume权限规则:
|
||||
- 用户对自己的User Volume有全部权限
|
||||
- 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限
|
||||
- 更注重连接身份验证,而不是复杂的权限检查
|
||||
"""
|
||||
try:
|
||||
# 获取当前用户名
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查基本连接状态
|
||||
with self._connection.cursor() as cursor:
|
||||
# 简单的连接测试,如果能执行查询说明用户有基本权限
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
logger.debug(
|
||||
"User Volume permission check for %s, operation %s: granted (basic connection verified)",
|
||||
current_user,
|
||||
operation.name,
|
||||
)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"User Volume permission check failed: cannot verify basic connection for %s", current_user
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("User Volume permission check failed")
|
||||
# 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示
|
||||
logger.info("User Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool:
|
||||
"""检查Table Volume权限
|
||||
|
||||
Table Volume权限规则:
|
||||
- Table Volume权限继承对应表的权限
|
||||
- SELECT权限 -> 可以READ/LIST文件
|
||||
- INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件
|
||||
"""
|
||||
if not dataset_id:
|
||||
logger.warning("dataset_id is required for table volume permission check")
|
||||
return False
|
||||
|
||||
table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id
|
||||
|
||||
try:
|
||||
# 检查表权限
|
||||
permissions = self._get_table_permissions(table_name)
|
||||
required_permissions = set(operation.value.split(","))
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
table_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Table volume permission check failed for %s", table_name)
|
||||
return False
|
||||
|
||||
def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
|
||||
"""检查External Volume权限
|
||||
|
||||
External Volume权限规则:
|
||||
- 尝试获取对External Volume的权限
|
||||
- 如果权限检查失败,进行备选验证
|
||||
- 对于开发环境,提供更宽松的权限检查
|
||||
"""
|
||||
if not self._volume_name:
|
||||
logger.warning("volume_name is required for external volume permission check")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 检查External Volume权限
|
||||
permissions = self._get_external_volume_permissions(self._volume_name)
|
||||
|
||||
# External Volume权限映射:根据操作类型确定所需权限
|
||||
required_permissions = set()
|
||||
|
||||
if operation in [VolumePermission.READ, VolumePermission.LIST]:
|
||||
required_permissions.add("read")
|
||||
elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
|
||||
required_permissions.add("write")
|
||||
|
||||
# 检查是否有所需的所有权限
|
||||
has_permission = required_permissions.issubset(permissions)
|
||||
|
||||
logger.debug(
|
||||
"External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
|
||||
self._volume_name,
|
||||
operation.name,
|
||||
required_permissions,
|
||||
permissions,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
# 如果权限检查失败,尝试备选验证
|
||||
if not has_permission:
|
||||
logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)
|
||||
|
||||
# 备选验证:尝试列出Volume来验证基本访问权限
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == self._volume_name:
|
||||
logger.info("Fallback verification successful for %s", self._volume_name)
|
||||
return True
|
||||
except Exception as fallback_e:
|
||||
logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)
|
||||
|
||||
return has_permission
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("External volume permission check failed for %s", self._volume_name)
|
||||
logger.info("External Volume permission check failed, but permission checking is disabled in this version")
|
||||
return False
|
||||
|
||||
def _get_table_permissions(self, table_name: str) -> set[str]:
|
||||
"""获取用户对指定表的权限
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
用户对该表的权限集合
|
||||
"""
|
||||
cache_key = f"table:{table_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找对该表的权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
object_name = grant[2] if len(grant) > 2 else ""
|
||||
|
||||
# 检查是否是对该表的权限
|
||||
if (
|
||||
object_type == "TABLE"
|
||||
and object_name == table_name
|
||||
or object_type == "SCHEMA"
|
||||
and object_name in table_name
|
||||
):
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
# 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
|
||||
permissions.add("SELECT")
|
||||
except Exception:
|
||||
logger.debug("Cannot query table %s, no SELECT permission", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check table permissions for %s: %s", table_name, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_current_username(self) -> str:
|
||||
"""获取当前用户名"""
|
||||
if self._current_username:
|
||||
return self._current_username
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
self._current_username = result[0]
|
||||
return str(self._current_username)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get current username")
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_user_permissions(self, username: str) -> set[str]:
|
||||
"""获取用户的基本权限集合"""
|
||||
cache_key = f"user_permissions:{username}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查当前用户权限
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
# 解析权限结果,查找用户的基本权限
|
||||
for grant in grants:
|
||||
if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...)
|
||||
privilege = grant[0].upper()
|
||||
object_type = grant[1].upper() if len(grant) > 1 else ""
|
||||
|
||||
# 收集所有相关权限
|
||||
if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
|
||||
if privilege == "ALL":
|
||||
permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
|
||||
else:
|
||||
permissions.add(privilege)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check user permissions for %s: %s", username, e)
|
||||
# 安全默认:权限检查失败时拒绝访问
|
||||
pass
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
|
||||
"""获取用户对指定External Volume的权限
|
||||
|
||||
Args:
|
||||
volume_name: External Volume名称
|
||||
|
||||
Returns:
|
||||
用户对该Volume的权限集合
|
||||
"""
|
||||
cache_key = f"external_volume:{volume_name}"
|
||||
|
||||
if cache_key in self._permission_cache:
|
||||
return self._permission_cache[cache_key]
|
||||
|
||||
permissions = set()
|
||||
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
# 使用正确的ClickZetta语法检查Volume权限
|
||||
logger.info("Checking permissions for volume: %s", volume_name)
|
||||
cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
|
||||
grants = cursor.fetchall()
|
||||
|
||||
logger.info("Raw grants result for %s: %s", volume_name, grants)
|
||||
|
||||
# 解析权限结果
|
||||
# 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
|
||||
# grantee_name, grantor_name, grant_option, granted_time)
|
||||
for grant in grants:
|
||||
logger.info("Processing grant: %s", grant)
|
||||
if len(grant) >= 5:
|
||||
granted_type = grant[0]
|
||||
privilege = grant[1].upper()
|
||||
granted_on = grant[3]
|
||||
object_name = grant[4]
|
||||
|
||||
logger.info(
|
||||
"Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
|
||||
granted_type,
|
||||
privilege,
|
||||
granted_on,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 检查是否是对该Volume的权限或者是层级权限
|
||||
if (
|
||||
granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
|
||||
) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
|
||||
logger.info("Matching grant found for %s", volume_name)
|
||||
|
||||
if "READ" in privilege:
|
||||
permissions.add("read")
|
||||
logger.info("Added READ permission for %s", volume_name)
|
||||
if "WRITE" in privilege:
|
||||
permissions.add("write")
|
||||
logger.info("Added WRITE permission for %s", volume_name)
|
||||
if "ALTER" in privilege:
|
||||
permissions.add("alter")
|
||||
logger.info("Added ALTER permission for %s", volume_name)
|
||||
if privilege == "ALL":
|
||||
permissions.update(["read", "write", "alter"])
|
||||
logger.info("Added ALL permissions for %s", volume_name)
|
||||
|
||||
logger.info("Final permissions for %s: %s", volume_name, permissions)
|
||||
|
||||
# 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限
|
||||
if not permissions:
|
||||
try:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
permissions.add("read") # 至少有读权限
|
||||
logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
|
||||
break
|
||||
except Exception:
|
||||
logger.debug("Cannot access volume %s, no basic permission", volume_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
|
||||
# 在权限检查失败时,尝试基本的Volume访问验证
|
||||
try:
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute("SHOW VOLUMES")
|
||||
volumes = cursor.fetchall()
|
||||
for volume in volumes:
|
||||
if len(volume) > 0 and volume[0] == volume_name:
|
||||
logger.info("Basic volume access verified for %s", volume_name)
|
||||
permissions.add("read")
|
||||
permissions.add("write") # 假设有写权限
|
||||
break
|
||||
except Exception as basic_e:
|
||||
logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
|
||||
# 最后的备选方案:假设有基本权限
|
||||
permissions.add("read")
|
||||
|
||||
# 缓存权限信息
|
||||
self._permission_cache[cache_key] = permissions
|
||||
return permissions
|
||||
|
||||
def clear_permission_cache(self):
|
||||
"""清空权限缓存"""
|
||||
self._permission_cache.clear()
|
||||
logger.debug("Permission cache cleared")
|
||||
|
||||
def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]:
|
||||
"""获取权限摘要
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID (用于table volume)
|
||||
|
||||
Returns:
|
||||
权限摘要字典
|
||||
"""
|
||||
summary = {}
|
||||
|
||||
for operation in VolumePermission:
|
||||
summary[operation.name.lower()] = self.check_permission(operation, dataset_id)
|
||||
|
||||
return summary
|
||||
|
||||
def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
|
||||
"""检查文件路径的权限继承
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
operation: 要执行的操作
|
||||
|
||||
Returns:
|
||||
True if user has permission, False otherwise
|
||||
"""
|
||||
try:
|
||||
# 解析文件路径
|
||||
path_parts = file_path.strip("/").split("/")
|
||||
|
||||
if not path_parts:
|
||||
logger.warning("Invalid file path for permission inheritance check")
|
||||
return False
|
||||
|
||||
# 对于Table Volume,第一层是dataset_id
|
||||
if self._volume_type == "table":
|
||||
if len(path_parts) < 1:
|
||||
return False
|
||||
|
||||
dataset_id = path_parts[0]
|
||||
|
||||
# 检查对dataset的权限
|
||||
has_dataset_permission = self.check_permission(operation, dataset_id)
|
||||
|
||||
if not has_dataset_permission:
|
||||
logger.debug("Permission denied for dataset %s", dataset_id)
|
||||
return False
|
||||
|
||||
# 检查路径遍历攻击
|
||||
if self._contains_path_traversal(file_path):
|
||||
logger.warning("Path traversal attack detected: %s", file_path)
|
||||
return False
|
||||
|
||||
# 检查是否访问敏感目录
|
||||
if self._is_sensitive_path(file_path):
|
||||
logger.warning("Access to sensitive path denied: %s", file_path)
|
||||
return False
|
||||
|
||||
logger.debug("Permission inherited for path %s", file_path)
|
||||
return True
|
||||
|
||||
elif self._volume_type == "user":
|
||||
# User Volume的权限继承
|
||||
current_user = self._get_current_username()
|
||||
|
||||
# 检查是否试图访问其他用户的目录
|
||||
if len(path_parts) > 1 and path_parts[0] != current_user:
|
||||
logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
|
||||
return False
|
||||
|
||||
# 检查基本权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
elif self._volume_type == "external":
|
||||
# External Volume的权限继承
|
||||
# 检查对External Volume的权限
|
||||
return self.check_permission(operation)
|
||||
|
||||
else:
|
||||
logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Permission inheritance check failed")
|
||||
return False
|
||||
|
||||
def _contains_path_traversal(self, file_path: str) -> bool:
|
||||
"""检查路径是否包含路径遍历攻击"""
|
||||
# 检查常见的路径遍历模式
|
||||
traversal_patterns = [
|
||||
"../",
|
||||
"..\\",
|
||||
"..%2f",
|
||||
"..%2F",
|
||||
"..%5c",
|
||||
"..%5C",
|
||||
"%2e%2e%2f",
|
||||
"%2e%2e%5c",
|
||||
"....//",
|
||||
"....\\\\",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
for pattern in traversal_patterns:
|
||||
if pattern in file_path_lower:
|
||||
return True
|
||||
|
||||
# 检查绝对路径
|
||||
if file_path.startswith("/") or file_path.startswith("\\"):
|
||||
return True
|
||||
|
||||
# 检查Windows驱动器路径
|
||||
if len(file_path) >= 2 and file_path[1] == ":":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_sensitive_path(self, file_path: str) -> bool:
|
||||
"""检查路径是否为敏感路径"""
|
||||
sensitive_patterns = [
|
||||
"passwd",
|
||||
"shadow",
|
||||
"hosts",
|
||||
"config",
|
||||
"secrets",
|
||||
"private",
|
||||
"key",
|
||||
"certificate",
|
||||
"cert",
|
||||
"ssl",
|
||||
"database",
|
||||
"backup",
|
||||
"dump",
|
||||
"log",
|
||||
"tmp",
|
||||
]
|
||||
|
||||
file_path_lower = file_path.lower()
|
||||
|
||||
return any(pattern in file_path_lower for pattern in sensitive_patterns)
|
||||
|
||||
def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool:
|
||||
"""验证操作权限
|
||||
|
||||
Args:
|
||||
operation: 操作名称 (save|load|exists|delete|scan)
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Returns:
|
||||
True if operation is allowed, False otherwise
|
||||
"""
|
||||
operation_mapping = {
|
||||
"save": VolumePermission.WRITE,
|
||||
"load": VolumePermission.READ,
|
||||
"load_once": VolumePermission.READ,
|
||||
"load_stream": VolumePermission.READ,
|
||||
"download": VolumePermission.READ,
|
||||
"exists": VolumePermission.READ,
|
||||
"delete": VolumePermission.DELETE,
|
||||
"scan": VolumePermission.LIST,
|
||||
}
|
||||
|
||||
if operation not in operation_mapping:
|
||||
logger.warning("Unknown operation: %s", operation)
|
||||
return False
|
||||
|
||||
volume_permission = operation_mapping[operation]
|
||||
return self.check_permission(volume_permission, dataset_id)
|
||||
|
||||
|
||||
class VolumePermissionError(Exception):
|
||||
"""Volume权限错误异常"""
|
||||
|
||||
def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None):
|
||||
self.operation = operation
|
||||
self.volume_type = volume_type
|
||||
self.dataset_id = dataset_id
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_volume_permission(
|
||||
permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""权限检查装饰器函数
|
||||
|
||||
Args:
|
||||
permission_manager: 权限管理器
|
||||
operation: 操作名称
|
||||
dataset_id: 数据集ID
|
||||
|
||||
Raises:
|
||||
VolumePermissionError: 如果没有权限
|
||||
"""
|
||||
if not permission_manager.validate_operation(operation, dataset_id):
|
||||
error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume"
|
||||
if dataset_id:
|
||||
error_message += f" (dataset: {dataset_id})"
|
||||
|
||||
raise VolumePermissionError(
|
||||
error_message,
|
||||
operation=operation,
|
||||
volume_type=permission_manager._volume_type or "unknown",
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
|
@ -5,6 +5,7 @@ class StorageType(StrEnum):
|
|||
ALIYUN_OSS = "aliyun-oss"
|
||||
AZURE_BLOB = "azure-blob"
|
||||
BAIDU_OBS = "baidu-obs"
|
||||
CLICKZETTA_VOLUME = "clickzetta-volume"
|
||||
GOOGLE_STORAGE = "google-storage"
|
||||
HUAWEI_OBS = "huawei-obs"
|
||||
LOCAL = "local"
|
||||
|
|
|
|||
|
|
@ -194,6 +194,7 @@ vdb = [
|
|||
"alibabacloud_tea_openapi~=0.3.9",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.7.16",
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==2.4.0",
|
||||
|
|
@ -213,3 +214,4 @@ vdb = [
|
|||
"xinference-client~=1.2.2",
|
||||
"mo-vector~=0.1.13",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
"""Integration tests for ClickZetta Volume Storage."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
|
||||
class TestClickZettaVolumeStorage(unittest.TestCase):
|
||||
"""Test cases for ClickZetta Volume Storage."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment."""
|
||||
self.config = ClickZettaVolumeConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
|
||||
volume_type="table",
|
||||
table_prefix="test_dataset_",
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_user_volume_operations(self):
|
||||
"""Test basic operations with User Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "user"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations
|
||||
test_filename = "test_file.txt"
|
||||
test_content = b"Hello, ClickZetta Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test streaming
|
||||
stream_content = b""
|
||||
for chunk in storage.load_stream(test_filename):
|
||||
stream_content += chunk
|
||||
assert stream_content == test_content
|
||||
|
||||
# Test download
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
storage.download(test_filename, temp_file.name)
|
||||
with open(temp_file.name, "rb") as f:
|
||||
downloaded_content = f.read()
|
||||
assert downloaded_content == test_content
|
||||
|
||||
# Test scan
|
||||
files = storage.scan("", files=True, directories=False)
|
||||
assert test_filename in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
|
||||
def test_table_volume_operations(self):
|
||||
"""Test basic operations with Table Volume."""
|
||||
config = self.config
|
||||
config.volume_type = "table"
|
||||
|
||||
storage = ClickZettaVolumeStorage(config)
|
||||
|
||||
# Test file operations with dataset_id
|
||||
dataset_id = "12345"
|
||||
test_filename = f"{dataset_id}/test_file.txt"
|
||||
test_content = b"Hello, Table Volume!"
|
||||
|
||||
# Save file
|
||||
storage.save(test_filename, test_content)
|
||||
|
||||
# Check if file exists
|
||||
assert storage.exists(test_filename)
|
||||
|
||||
# Load file
|
||||
loaded_content = storage.load_once(test_filename)
|
||||
assert loaded_content == test_content
|
||||
|
||||
# Test scan for dataset
|
||||
files = storage.scan(dataset_id, files=True, directories=False)
|
||||
assert "test_file.txt" in files
|
||||
|
||||
# Delete file
|
||||
storage.delete(test_filename)
|
||||
assert not storage.exists(test_filename)
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test configuration validation."""
|
||||
# Test missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="", # Empty username should fail
|
||||
password="pass",
|
||||
instance="instance",
|
||||
)
|
||||
|
||||
# Test invalid volume type
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")
|
||||
|
||||
# Test external volume without volume_name
|
||||
with pytest.raises(ValueError):
|
||||
ClickZettaVolumeConfig(
|
||||
username="user",
|
||||
password="pass",
|
||||
instance="instance",
|
||||
volume_type="external",
|
||||
# Missing volume_name
|
||||
)
|
||||
|
||||
def test_volume_path_generation(self):
|
||||
"""Test volume path generation for different types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume path
|
||||
path = storage._get_volume_path("test.txt", "12345")
|
||||
assert path == "test_dataset_12345/test.txt"
|
||||
|
||||
# Test path with existing dataset_id prefix
|
||||
path = storage._get_volume_path("12345/test.txt")
|
||||
assert path == "12345/test.txt"
|
||||
|
||||
# Test user volume
|
||||
storage._config.volume_type = "user"
|
||||
path = storage._get_volume_path("test.txt")
|
||||
assert path == "test.txt"
|
||||
|
||||
def test_sql_prefix_generation(self):
|
||||
"""Test SQL prefix generation for different volume types."""
|
||||
storage = ClickZettaVolumeStorage(self.config)
|
||||
|
||||
# Test table volume SQL prefix
|
||||
prefix = storage._get_volume_sql_prefix("12345")
|
||||
assert prefix == "TABLE VOLUME test_dataset_12345"
|
||||
|
||||
# Test user volume SQL prefix
|
||||
storage._config.volume_type = "user"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "USER VOLUME"
|
||||
|
||||
# Test external volume SQL prefix
|
||||
storage._config.volume_type = "external"
|
||||
storage._config.volume_name = "my_external_volume"
|
||||
prefix = storage._get_volume_sql_prefix()
|
||||
assert prefix == "VOLUME my_external_volume"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# Clickzetta Integration Tests
|
||||
|
||||
## Running Tests
|
||||
|
||||
To run the Clickzetta integration tests, you need to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export CLICKZETTA_USERNAME=your_username
|
||||
export CLICKZETTA_PASSWORD=your_password
|
||||
export CLICKZETTA_INSTANCE=your_instance
|
||||
export CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
export CLICKZETTA_WORKSPACE=your_workspace
|
||||
export CLICKZETTA_VCLUSTER=your_vcluster
|
||||
export CLICKZETTA_SCHEMA=dify
|
||||
```
|
||||
|
||||
Then run the tests:
|
||||
|
||||
```bash
|
||||
pytest api/tests/integration_tests/vdb/clickzetta/
|
||||
```
|
||||
|
||||
## Security Note
|
||||
|
||||
Never commit credentials to the repository. Always use environment variables or secure credential management systems.
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
|
||||
"""
|
||||
Test cases for Clickzetta vector database integration.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
"""Create a Clickzetta vector store instance for testing."""
|
||||
# Skip test if Clickzetta credentials are not configured
|
||||
if not os.getenv("CLICKZETTA_USERNAME"):
|
||||
pytest.skip("CLICKZETTA_USERNAME is not configured")
|
||||
if not os.getenv("CLICKZETTA_PASSWORD"):
|
||||
pytest.skip("CLICKZETTA_PASSWORD is not configured")
|
||||
if not os.getenv("CLICKZETTA_INSTANCE"):
|
||||
pytest.skip("CLICKZETTA_INSTANCE is not configured")
|
||||
|
||||
config = ClickzettaConfig(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", ""),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", ""),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", ""),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
|
||||
batch_size=10, # Small batch size for testing
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(
|
||||
collection_name="test_collection_" + str(os.getpid()),
|
||||
config=config
|
||||
)
|
||||
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
try:
|
||||
vector.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
# Prepare test data
|
||||
texts = [
|
||||
"这是第一个测试文档,包含一些中文内容。",
|
||||
"This is the second test document with English content.",
|
||||
"第三个文档混合了English和中文内容。",
|
||||
]
|
||||
embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.5, 0.6, 0.7, 0.8],
|
||||
[0.9, 1.0, 1.1, 1.2],
|
||||
]
|
||||
documents = [
|
||||
Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
|
||||
for i, text in enumerate(texts)
|
||||
]
|
||||
|
||||
# Test create (initial insert)
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test text_exists
|
||||
assert vector_store.text_exists("doc_0")
|
||||
assert not vector_store.text_exists("doc_999")
|
||||
|
||||
# Test search_by_vector
|
||||
query_vector = [0.1, 0.2, 0.3, 0.4]
|
||||
results = vector_store.search_by_vector(query_vector, top_k=2)
|
||||
assert len(results) > 0
|
||||
assert results[0].page_content == texts[0] # Should match the first document
|
||||
|
||||
# Test search_by_full_text (Chinese)
|
||||
results = vector_store.search_by_full_text("中文", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with Chinese content
|
||||
|
||||
# Test search_by_full_text (English)
|
||||
results = vector_store.search_by_full_text("English", top_k=3)
|
||||
assert len(results) >= 2 # Should find documents with English content
|
||||
|
||||
# Test delete_by_ids
|
||||
vector_store.delete_by_ids(["doc_0"])
|
||||
assert not vector_store.text_exists("doc_0")
|
||||
assert vector_store.text_exists("doc_1")
|
||||
|
||||
# Test delete_by_metadata_field
|
||||
vector_store.delete_by_metadata_field("source", "test")
|
||||
assert not vector_store.text_exists("doc_1")
|
||||
assert not vector_store.text_exists("doc_2")
|
||||
|
||||
def test_clickzetta_vector_advanced_search(self, vector_store):
|
||||
"""Test advanced search features of Clickzetta vector store."""
|
||||
# Prepare test data with more complex metadata
|
||||
documents = []
|
||||
embeddings = []
|
||||
for i in range(10):
|
||||
doc = Document(
|
||||
page_content=f"Document {i}: " + get_example_text(),
|
||||
metadata={
|
||||
"doc_id": f"adv_doc_{i}",
|
||||
"category": "technical" if i % 2 == 0 else "general",
|
||||
"document_id": f"doc_{i // 3}", # Group documents
|
||||
"importance": i,
|
||||
}
|
||||
)
|
||||
documents.append(doc)
|
||||
# Create varied embeddings
|
||||
embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test vector search with document filter
|
||||
query_vector = [0.5, 1.0, 1.5, 2.0]
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=5,
|
||||
document_ids_filter=["doc_0", "doc_1"]
|
||||
)
|
||||
assert len(results) > 0
|
||||
# All results should belong to doc_0 or doc_1 groups
|
||||
for result in results:
|
||||
assert result.metadata["document_id"] in ["doc_0", "doc_1"]
|
||||
|
||||
# Test score threshold
|
||||
results = vector_store.search_by_vector(
|
||||
query_vector,
|
||||
top_k=10,
|
||||
score_threshold=0.5
|
||||
)
|
||||
# Check that all results have a score above threshold
|
||||
for result in results:
|
||||
assert result.metadata.get("score", 0) >= 0.5
|
||||
|
||||
def test_clickzetta_batch_operations(self, vector_store):
|
||||
"""Test batch insertion operations."""
|
||||
# Prepare large batch of documents
|
||||
batch_size = 25
|
||||
documents = []
|
||||
embeddings = []
|
||||
|
||||
for i in range(batch_size):
|
||||
doc = Document(
|
||||
page_content=f"Batch document {i}: This is a test document for batch processing.",
|
||||
metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}
|
||||
)
|
||||
documents.append(doc)
|
||||
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
|
||||
|
||||
# Test batch insert
|
||||
vector_store.add_texts(documents=documents, embeddings=embeddings)
|
||||
|
||||
# Verify all documents were inserted
|
||||
for i in range(batch_size):
|
||||
assert vector_store.text_exists(f"batch_doc_{i}")
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("batch", "test_batch")
|
||||
|
||||
def test_clickzetta_edge_cases(self, vector_store):
|
||||
"""Test edge cases and error handling."""
|
||||
# Test empty operations
|
||||
vector_store.create(texts=[], embeddings=[])
|
||||
vector_store.add_texts(documents=[], embeddings=[])
|
||||
vector_store.delete_by_ids([])
|
||||
|
||||
# Test special characters in content
|
||||
special_doc = Document(
|
||||
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
|
||||
metadata={"doc_id": "special_doc", "test": "edge_case"}
|
||||
)
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4]]
|
||||
|
||||
vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
|
||||
assert vector_store.text_exists("special_doc")
|
||||
|
||||
# Test search with special characters
|
||||
results = vector_store.search_by_full_text("quotes", top_k=1)
|
||||
if results: # Full-text search might not be available
|
||||
assert len(results) > 0
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_ids(["special_doc"])
|
||||
|
||||
def test_clickzetta_full_text_search_modes(self, vector_store):
|
||||
"""Test different full-text search capabilities."""
|
||||
# Prepare documents with various language content
|
||||
documents = [
|
||||
Document(
|
||||
page_content="云器科技提供强大的Lakehouse解决方案",
|
||||
metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Clickzetta provides powerful Lakehouse solutions",
|
||||
metadata={"doc_id": "en_doc_1", "lang": "english"}
|
||||
),
|
||||
Document(
|
||||
page_content="Lakehouse是现代数据架构的重要组成部分",
|
||||
metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
|
||||
),
|
||||
Document(
|
||||
page_content="Modern data architecture includes Lakehouse technology",
|
||||
metadata={"doc_id": "en_doc_2", "lang": "english"}
|
||||
),
|
||||
]
|
||||
|
||||
embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
|
||||
|
||||
vector_store.create(texts=documents, embeddings=embeddings)
|
||||
|
||||
# Test Chinese full-text search
|
||||
results = vector_store.search_by_full_text("Lakehouse", top_k=4)
|
||||
assert len(results) >= 2 # Should find at least documents with "Lakehouse"
|
||||
|
||||
# Test English full-text search
|
||||
results = vector_store.search_by_full_text("solutions", top_k=2)
|
||||
assert len(results) >= 1 # Should find English documents with "solutions"
|
||||
|
||||
# Test mixed search
|
||||
results = vector_store.search_by_full_text("数据架构", top_k=2)
|
||||
assert len(results) >= 1 # Should find Chinese documents with this phrase
|
||||
|
||||
# Clean up
|
||||
vector_store.delete_by_metadata_field("lang", "chinese")
|
||||
vector_store.delete_by_metadata_field("lang", "english")
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
from clickzetta import connect
|
||||
|
||||
|
||||
def test_clickzetta_connection():
|
||||
"""Test direct connection to Clickzetta"""
|
||||
print("=== Testing direct Clickzetta connection ===")
|
||||
try:
|
||||
conn = connect(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
|
||||
instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
|
||||
service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
|
||||
database=os.getenv("CLICKZETTA_SCHEMA", "dify")
|
||||
)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Test basic connectivity
|
||||
cursor.execute("SELECT 1 as test")
|
||||
result = cursor.fetchone()
|
||||
print(f"✓ Connection test: {result}")
|
||||
|
||||
# Check if our test table exists
|
||||
cursor.execute("SHOW TABLES IN dify")
|
||||
tables = cursor.fetchall()
|
||||
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
|
||||
|
||||
# Check if test collection exists
|
||||
test_collection = "collection_test_dataset"
|
||||
if test_collection in [t[1] for t in tables if t[0] == 'dify']:
|
||||
cursor.execute(f"DESCRIBE dify.{test_collection}")
|
||||
columns = cursor.fetchall()
|
||||
print(f"✓ Table structure for {test_collection}:")
|
||||
for col in columns:
|
||||
print(f" - {col[0]}: {col[1]}")
|
||||
|
||||
# Check for indexes
|
||||
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
|
||||
indexes = cursor.fetchall()
|
||||
print(f"✓ Indexes on {test_collection}:")
|
||||
for idx in indexes:
|
||||
print(f" - {idx}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_dify_api():
|
||||
"""Test Dify API with Clickzetta backend"""
|
||||
print("\n=== Testing Dify API ===")
|
||||
base_url = "http://localhost:5001"
|
||||
|
||||
# Wait for API to be ready
|
||||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = requests.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
break
|
||||
except:
|
||||
if i == max_retries - 1:
|
||||
print("✗ Dify API is not responding")
|
||||
return False
|
||||
time.sleep(2)
|
||||
|
||||
# Check vector store configuration
|
||||
try:
|
||||
# This is a simplified check - in production, you'd use proper auth
|
||||
print("✓ Dify is configured to use Clickzetta as vector store")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ API test failed: {e}")
|
||||
return False
|
||||
|
||||
def verify_table_structure():
|
||||
"""Verify the table structure meets Dify requirements"""
|
||||
print("\n=== Verifying Table Structure ===")
|
||||
|
||||
expected_columns = {
|
||||
"id": "VARCHAR",
|
||||
"page_content": "VARCHAR",
|
||||
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
|
||||
"vector": "ARRAY<FLOAT>"
|
||||
}
|
||||
|
||||
expected_metadata_fields = [
|
||||
"doc_id",
|
||||
"doc_hash",
|
||||
"document_id",
|
||||
"dataset_id"
|
||||
]
|
||||
|
||||
print("✓ Expected table structure:")
|
||||
for col, dtype in expected_columns.items():
|
||||
print(f" - {col}: {dtype}")
|
||||
|
||||
print("\n✓ Required metadata fields:")
|
||||
for field in expected_metadata_fields:
|
||||
print(f" - {field}")
|
||||
|
||||
print("\n✓ Index requirements:")
|
||||
print(" - Vector index (HNSW) on 'vector' column")
|
||||
print(" - Full-text index on 'page_content' (optional)")
|
||||
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
|
||||
print(" - Functional index on metadata->>'$.document_id' (recommended)")
|
||||
|
||||
return True
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
|
||||
tests = [
|
||||
("Direct Clickzetta Connection", test_clickzetta_connection),
|
||||
("Dify API Status", test_dify_api),
|
||||
("Table Structure Verification", verify_table_structure),
|
||||
]
|
||||
|
||||
results = []
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, success))
|
||||
except Exception as e:
|
||||
print(f"\n✗ {test_name} crashed: {e}")
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*50)
|
||||
print("Test Summary:")
|
||||
print("="*50)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
|
||||
for test_name, success in results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f"{test_name}: {status}")
|
||||
|
||||
print(f"\nTotal: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
|
||||
print("\nNext steps:")
|
||||
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
|
||||
print("2. Access Dify at http://localhost:3000")
|
||||
print("3. Create a dataset and test vector storage with Clickzetta")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
|
@ -0,0 +1,928 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from constants.model_template import default_app_templates
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppService:
|
||||
"""Integration tests for AppService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for app service
|
||||
mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False
|
||||
mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None
|
||||
mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
# Mock ModelManager for model configuration
|
||||
mock_model_instance = mock_model_manager.return_value
|
||||
mock_model_instance.get_default_model_instance.return_value = None
|
||||
mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo")
|
||||
|
||||
yield {
|
||||
"feature_service": mock_feature_service,
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
"model_manager": mock_model_manager,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app creation with basic parameters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Create app
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Verify app was created correctly
|
||||
assert app.name == app_args["name"]
|
||||
assert app.description == app_args["description"]
|
||||
assert app.mode == app_args["mode"]
|
||||
assert app.icon_type == app_args["icon_type"]
|
||||
assert app.icon == app_args["icon"]
|
||||
assert app.icon_background == app_args["icon_background"]
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.api_rph == app_args["api_rph"]
|
||||
assert app.api_rpm == app_args["api_rpm"]
|
||||
assert app.created_by == account.id
|
||||
assert app.updated_by == account.id
|
||||
assert app.status == "normal"
|
||||
assert app.enable_site is True
|
||||
assert app.enable_api is True
|
||||
assert app.is_demo is False
|
||||
assert app.is_public is False
|
||||
assert app.is_universal is False
|
||||
|
||||
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app creation with different app modes.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Test different app modes
|
||||
# from AppMode enum in default_app_model_template
|
||||
app_modes = [v.value for v in default_app_templates]
|
||||
|
||||
for mode in app_modes:
|
||||
app_args = {
|
||||
"name": f"{fake.company()} {mode}",
|
||||
"description": f"Test app for {mode} mode",
|
||||
"mode": mode,
|
||||
"icon_type": "emoji",
|
||||
"icon": "🚀",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Verify app mode was set correctly
|
||||
assert app.mode == mode
|
||||
assert app.name == app_args["name"]
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.created_by == account.id
|
||||
|
||||
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
created_app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app using the service
|
||||
retrieved_app = app_service.get_app(created_app)
|
||||
|
||||
# Verify retrieved app matches created app
|
||||
assert retrieved_app.id == created_app.id
|
||||
assert retrieved_app.name == created_app.name
|
||||
assert retrieved_app.description == created_app.description
|
||||
assert retrieved_app.mode == created_app.mode
|
||||
assert retrieved_app.tenant_id == created_app.tenant_id
|
||||
assert retrieved_app.created_by == created_app.created_by
|
||||
|
||||
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful paginated app list retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create multiple apps
|
||||
app_names = [fake.company() for _ in range(5)]
|
||||
for name in app_names:
|
||||
app_args = {
|
||||
"name": name,
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "📱",
|
||||
"icon_background": "#96CEB4",
|
||||
}
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get paginated apps
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Verify pagination results
|
||||
assert paginated_apps is not None
|
||||
assert len(paginated_apps.items) >= 5 # Should have at least 5 apps
|
||||
assert paginated_apps.page == 1
|
||||
assert paginated_apps.per_page == 10
|
||||
|
||||
# Verify all apps belong to the correct tenant
|
||||
for app in paginated_apps.items:
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.mode == "chat"
|
||||
|
||||
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test paginated app list with various filters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with different modes
|
||||
chat_app_args = {
|
||||
"name": "Chat App",
|
||||
"description": "A chat application",
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "💬",
|
||||
"icon_background": "#FF6B6B",
|
||||
}
|
||||
completion_app_args = {
|
||||
"name": "Completion App",
|
||||
"description": "A completion application",
|
||||
"mode": "completion",
|
||||
"icon_type": "emoji",
|
||||
"icon": "✍️",
|
||||
"icon_background": "#4ECDC4",
|
||||
}
|
||||
|
||||
chat_app = app_service.create_app(tenant.id, chat_app_args, account)
|
||||
completion_app = app_service.create_app(tenant.id, completion_app_args, account)
|
||||
|
||||
# Test filter by mode
|
||||
chat_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
}
|
||||
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
|
||||
assert len(chat_apps.items) == 1
|
||||
assert chat_apps.items[0].mode == "chat"
|
||||
|
||||
# Test filter by name
|
||||
name_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"name": "Chat",
|
||||
}
|
||||
filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
|
||||
assert len(filtered_apps.items) == 1
|
||||
assert "Chat" in filtered_apps.items[0].name
|
||||
|
||||
# Test filter by created_by_me
|
||||
created_by_me_args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "completion",
|
||||
"is_created_by_me": True,
|
||||
}
|
||||
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
|
||||
assert len(my_apps.items) == 1
|
||||
|
||||
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test paginated app list with tag filters.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create an app
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🏷️",
|
||||
"icon_background": "#FFEAA7",
|
||||
}
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Mock TagService to return the app ID for tag filtering
|
||||
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
|
||||
mock_tag_service.return_value = [app.id]
|
||||
|
||||
# Test with tag filter
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"tag_ids": ["tag1", "tag2"],
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Verify tag service was called
|
||||
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
|
||||
|
||||
# Verify results
|
||||
assert paginated_apps is not None
|
||||
assert len(paginated_apps.items) == 1
|
||||
assert paginated_apps.items[0].id == app.id
|
||||
|
||||
# Test with tag filter that returns no results
|
||||
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
|
||||
mock_tag_service.return_value = []
|
||||
|
||||
args = {
|
||||
"page": 1,
|
||||
"limit": 10,
|
||||
"mode": "chat",
|
||||
"tag_ids": ["nonexistent_tag"],
|
||||
}
|
||||
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
|
||||
# Should return None when no apps match tag filter
|
||||
assert paginated_apps is None
|
||||
|
||||
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app update with all fields.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_name = app.name
|
||||
original_description = app.description
|
||||
original_icon = app.icon
|
||||
original_icon_background = app.icon_background
|
||||
original_use_icon_as_answer_icon = app.use_icon_as_answer_icon
|
||||
|
||||
# Update app
|
||||
update_args = {
|
||||
"name": "Updated App Name",
|
||||
"description": "Updated app description",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔄",
|
||||
"icon_background": "#FF8C42",
|
||||
"use_icon_as_answer_icon": True,
|
||||
}
|
||||
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app(app, update_args)
|
||||
|
||||
# Verify updated fields
|
||||
assert updated_app.name == update_args["name"]
|
||||
assert updated_app.description == update_args["description"]
|
||||
assert updated_app.icon == update_args["icon"]
|
||||
assert updated_app.icon_background == update_args["icon_background"]
|
||||
assert updated_app.use_icon_as_answer_icon is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app name update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original name
|
||||
original_name = app.name
|
||||
|
||||
# Update app name
|
||||
new_name = "New App Name"
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_name(app, new_name)
|
||||
|
||||
assert updated_app.name == new_name
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app icon update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🎯",
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_icon = app.icon
|
||||
original_icon_background = app.icon_background
|
||||
|
||||
# Update app icon
|
||||
new_icon = "🌟"
|
||||
new_icon_background = "#FFD93D"
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
|
||||
|
||||
assert updated_app.icon == new_icon
|
||||
assert updated_app.icon_background == new_icon_background
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app site status update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🌐",
|
||||
"icon_background": "#74B9FF",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original site status
|
||||
original_site_status = app.enable_site
|
||||
|
||||
# Update site status to disabled
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_site_status(app, False)
|
||||
assert updated_app.enable_site is False
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Update site status back to enabled
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_site_status(updated_app, True)
|
||||
assert updated_app.enable_site is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app API status update.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔌",
|
||||
"icon_background": "#A29BFE",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original API status
|
||||
original_api_status = app.enable_api
|
||||
|
||||
# Update API status to disabled
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_api_status(app, False)
|
||||
assert updated_app.enable_api is False
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Update API status back to enabled
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
updated_app = app_service.update_app_api_status(updated_app, True)
|
||||
assert updated_app.enable_api is True
|
||||
assert updated_app.updated_by == account.id
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app site status update when status doesn't change.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔄",
|
||||
"icon_background": "#FD79A8",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store original values
|
||||
original_site_status = app.enable_site
|
||||
original_updated_at = app.updated_at
|
||||
|
||||
# Update site status to the same value (no change)
|
||||
updated_app = app_service.update_app_site_status(app, original_site_status)
|
||||
|
||||
# Verify app is returned unchanged
|
||||
assert updated_app.id == app.id
|
||||
assert updated_app.enable_site == original_site_status
|
||||
assert updated_app.updated_at == original_updated_at
|
||||
|
||||
# Verify other fields remain unchanged
|
||||
assert updated_app.name == app.name
|
||||
assert updated_app.description == app.description
|
||||
assert updated_app.mode == app.mode
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app deletion.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🗑️",
|
||||
"icon_background": "#E17055",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store app ID for verification
|
||||
app_id = app.id
|
||||
|
||||
# Mock the async deletion task
|
||||
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
|
||||
mock_delete_task.delay.return_value = None
|
||||
|
||||
# Delete app
|
||||
app_service.delete_app(app)
|
||||
|
||||
# Verify async deletion task was called
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app deletion with related data cleanup.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🧹",
|
||||
"icon_background": "#00B894",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Store app ID for verification
|
||||
app_id = app.id
|
||||
|
||||
# Mock webapp auth cleanup
|
||||
mock_external_service_dependencies[
|
||||
"feature_service"
|
||||
].get_system_features.return_value.webapp_auth.enabled = True
|
||||
|
||||
# Mock the async deletion task
|
||||
with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task:
|
||||
mock_delete_task.delay.return_value = None
|
||||
|
||||
# Delete app
|
||||
app_service.delete_app(app)
|
||||
|
||||
# Verify webapp auth cleanup was called
|
||||
mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with(
|
||||
app_id
|
||||
)
|
||||
|
||||
# Verify async deletion task was called
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app metadata retrieval.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "📊",
|
||||
"icon_background": "#6C5CE7",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app metadata
|
||||
app_meta = app_service.get_app_meta(app)
|
||||
|
||||
# Verify metadata contains expected fields
|
||||
assert "tool_icons" in app_meta
|
||||
# Note: get_app_meta currently only returns tool_icons
|
||||
|
||||
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app code retrieval by app ID.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🔗",
|
||||
"icon_background": "#FDCB6E",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Get app code by ID
|
||||
app_code = AppService.get_app_code_by_id(app.id)
|
||||
|
||||
# Verify app code was retrieved correctly
|
||||
# Note: Site would be created when App is created, site.code is auto-generated
|
||||
assert app_code is not None
|
||||
assert len(app_code) > 0
|
||||
|
||||
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app ID retrieval by app code.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🆔",
|
||||
"icon_background": "#E84393",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
# Create a site for the app
|
||||
site = Site()
|
||||
site.app_id = app.id
|
||||
site.code = fake.postalcode()
|
||||
site.title = fake.company()
|
||||
site.status = "normal"
|
||||
site.default_language = "en-US"
|
||||
site.customize_token_strategy = "uuid"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
|
||||
# Get app ID by code
|
||||
app_id = AppService.get_app_id_by_code(site.code)
|
||||
|
||||
# Verify app ID was retrieved correctly
|
||||
assert app_id == app.id
|
||||
|
||||
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app creation with invalid mode.
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments with invalid mode
|
||||
app_args = {
|
||||
"name": fake.company(),
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "invalid_mode", # Invalid mode
|
||||
"icon_type": "emoji",
|
||||
"icon": "❌",
|
||||
"icon_background": "#D63031",
|
||||
}
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Attempt to create app with invalid mode
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
|
@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable():
|
|||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == []
|
||||
assert executor.params is None
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
|
|
@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable():
|
|||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == []
|
||||
assert executor.params is None
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
|
|
@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable():
|
|||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == []
|
||||
assert executor.params is None
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
|
|
@ -245,7 +245,7 @@ def test_executor_with_form_data():
|
|||
assert executor.url == "https://api.example.com/upload"
|
||||
assert "Content-Type" in executor.headers
|
||||
assert "multipart/form-data" in executor.headers["Content-Type"]
|
||||
assert executor.params == []
|
||||
assert executor.params is None
|
||||
assert executor.json is None
|
||||
# '__multipart_placeholder__' is expected when no file inputs exist,
|
||||
# to ensure the request is treated as multipart/form-data by the backend.
|
||||
|
|
|
|||
58
api/uv.lock
58
api/uv.lock
|
|
@ -983,6 +983,25 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clickzetta-connector-python"
|
||||
version = "0.8.102"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "future" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pandas" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloudscraper"
|
||||
version = "1.2.71"
|
||||
|
|
@ -1383,6 +1402,7 @@ vdb = [
|
|||
{ name = "alibabacloud-tea-openapi" },
|
||||
{ name = "chromadb" },
|
||||
{ name = "clickhouse-connect" },
|
||||
{ name = "clickzetta-connector-python" },
|
||||
{ name = "couchbase" },
|
||||
{ name = "elasticsearch" },
|
||||
{ name = "mo-vector" },
|
||||
|
|
@ -1568,6 +1588,7 @@ vdb = [
|
|||
{ name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" },
|
||||
{ name = "chromadb", specifier = "==0.5.20" },
|
||||
{ name = "clickhouse-connect", specifier = "~=0.7.16" },
|
||||
{ name = "clickzetta-connector-python", specifier = ">=0.8.102" },
|
||||
{ name = "couchbase", specifier = "~=4.3.0" },
|
||||
{ name = "elasticsearch", specifier = "==8.14.0" },
|
||||
{ name = "mo-vector", specifier = "~=0.1.13" },
|
||||
|
|
@ -2111,7 +2132,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "google-cloud-bigquery"
|
||||
version = "3.34.0"
|
||||
version = "3.30.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "google-api-core", extra = ["grpc"] },
|
||||
|
|
@ -2122,9 +2143,9 @@ dependencies = [
|
|||
{ name = "python-dateutil" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -3918,11 +3939,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "24.2"
|
||||
version = "23.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -4302,6 +4323,31 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "14.0.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
|
|
|
|||
|
|
@ -333,6 +333,25 @@ OPENDAL_SCHEME=fs
|
|||
# Configurations for OpenDAL Local File System.
|
||||
OPENDAL_FS_ROOT=storage
|
||||
|
||||
# ClickZetta Volume Configuration (for storage backend)
|
||||
# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume
|
||||
# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters
|
||||
|
||||
# Volume type selection (three types available):
|
||||
# - user: Personal/small team use, simple config, user-level permissions
|
||||
# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions
|
||||
# - external: Data lake integration, external storage connection, volume-level + storage-level permissions
|
||||
CLICKZETTA_VOLUME_TYPE=user
|
||||
|
||||
# External Volume name (required only when TYPE=external)
|
||||
CLICKZETTA_VOLUME_NAME=
|
||||
|
||||
# Table Volume table prefix (used only when TYPE=table)
|
||||
CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_
|
||||
|
||||
# Dify file directory prefix (isolates from other apps, recommended to keep default)
|
||||
CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km
|
||||
|
||||
# S3 Configuration
|
||||
#
|
||||
S3_ENDPOINT=
|
||||
|
|
@ -416,7 +435,7 @@ SUPABASE_URL=your-server-url
|
|||
# ------------------------------
|
||||
|
||||
# The type of vector store to use.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
|
|
@ -655,6 +674,20 @@ TABLESTORE_ACCESS_KEY_ID=xxx
|
|||
TABLESTORE_ACCESS_KEY_SECRET=xxx
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false
|
||||
|
||||
# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta`
|
||||
CLICKZETTA_USERNAME=
|
||||
CLICKZETTA_PASSWORD=
|
||||
CLICKZETTA_INSTANCE=
|
||||
CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
CLICKZETTA_WORKSPACE=quick_start
|
||||
CLICKZETTA_VCLUSTER=default_ap
|
||||
CLICKZETTA_SCHEMA=dify
|
||||
CLICKZETTA_BATCH_SIZE=100
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX=true
|
||||
CLICKZETTA_ANALYZER_TYPE=chinese
|
||||
CLICKZETTA_ANALYZER_MODE=smart
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance
|
||||
|
||||
# ------------------------------
|
||||
# Knowledge Configuration
|
||||
# ------------------------------
|
||||
|
|
|
|||
|
|
@ -93,6 +93,10 @@ x-shared-env: &shared-api-worker-env
|
|||
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
|
||||
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
|
||||
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
|
||||
CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user}
|
||||
CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-}
|
||||
CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_}
|
||||
CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km}
|
||||
S3_ENDPOINT: ${S3_ENDPOINT:-}
|
||||
S3_REGION: ${S3_REGION:-us-east-1}
|
||||
S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai}
|
||||
|
|
@ -313,6 +317,18 @@ x-shared-env: &shared-api-worker-env
|
|||
TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
|
||||
TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
|
||||
TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
|
||||
CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-}
|
||||
CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-}
|
||||
CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-}
|
||||
CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com}
|
||||
CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start}
|
||||
CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap}
|
||||
CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify}
|
||||
CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100}
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true}
|
||||
CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese}
|
||||
CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart}
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance}
|
||||
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
|
||||
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
|
||||
ETL_TYPE: ${ETL_TYPE:-dify}
|
||||
|
|
|
|||
|
|
@ -1,41 +0,0 @@
|
|||
'use client'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
RiAddLine,
|
||||
RiArrowRightLine,
|
||||
} from '@remixicon/react'
|
||||
import Link from 'next/link'
|
||||
|
||||
type CreateAppCardProps = {
|
||||
ref?: React.Ref<HTMLAnchorElement>
|
||||
}
|
||||
|
||||
const CreateAppCard = ({ ref }: CreateAppCardProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className='bg-background-default-dimm flex min-h-[160px] flex-col rounded-xl border-[0.5px]
|
||||
border-components-panel-border transition-all duration-200 ease-in-out'
|
||||
>
|
||||
<Link ref={ref} className='group flex grow cursor-pointer items-start p-4' href='/datasets/create'>
|
||||
<div className='flex items-center gap-3'>
|
||||
<div className='flex h-10 w-10 items-center justify-center rounded-lg border border-dashed border-divider-regular bg-background-default-lighter
|
||||
p-2 group-hover:border-solid group-hover:border-effects-highlight group-hover:bg-background-default-dodge'
|
||||
>
|
||||
<RiAddLine className='h-4 w-4 text-text-tertiary group-hover:text-text-accent' />
|
||||
</div>
|
||||
<div className='system-md-semibold text-text-secondary group-hover:text-text-accent'>{t('dataset.createDataset')}</div>
|
||||
</div>
|
||||
</Link>
|
||||
<div className='system-xs-regular p-4 pt-0 text-text-tertiary'>{t('dataset.createDatasetIntro')}</div>
|
||||
<Link className='group flex cursor-pointer items-center gap-1 rounded-b-xl border-t-[0.5px] border-divider-subtle p-4' href='/datasets/connect'>
|
||||
<div className='system-xs-medium text-text-tertiary group-hover:text-text-accent'>{t('dataset.connectDataset')}</div>
|
||||
<RiArrowRightLine className='h-3.5 w-3.5 text-text-tertiary group-hover:text-text-accent' />
|
||||
</Link>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
CreateAppCard.displayName = 'CreateAppCard'
|
||||
|
||||
export default CreateAppCard
|
||||
|
|
@ -106,8 +106,8 @@ const Uploader: FC<Props> = ({
|
|||
<div className='flex w-full items-center justify-center space-x-2'>
|
||||
<RiUploadCloud2Line className='h-6 w-6 text-text-tertiary' />
|
||||
<div className='text-text-tertiary'>
|
||||
{t('datasetCreation.stepOne.uploader.button')}
|
||||
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('datasetDocuments.list.batchModal.browse')}</span>
|
||||
{t('app.dslUploader.button')}
|
||||
<span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('app.dslUploader.browse')}</span>
|
||||
</div>
|
||||
</div>
|
||||
{dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />}
|
||||
|
|
|
|||
|
|
@ -370,20 +370,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
{app.description}
|
||||
</div>
|
||||
</div>
|
||||
<div className={cn(
|
||||
'absolute bottom-1 left-0 right-0 h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1',
|
||||
tags.length ? 'flex' : '!hidden group-hover:!flex',
|
||||
)}>
|
||||
<div className='absolute bottom-1 left-0 right-0 flex h-[42px] shrink-0 items-center pb-[6px] pl-[14px] pr-[6px] pt-1'>
|
||||
{isCurrentWorkspaceEditor && (
|
||||
<>
|
||||
<div className={cn('flex w-0 grow items-center gap-1')} onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
e.preventDefault()
|
||||
}}>
|
||||
<div className={cn(
|
||||
'mr-[41px] w-full grow group-hover:!mr-0 group-hover:!block',
|
||||
tags.length ? '!block' : '!hidden',
|
||||
)}>
|
||||
<div className='mr-[41px] w-full grow group-hover:!mr-0'>
|
||||
<TagSelector
|
||||
position='bl'
|
||||
type='app'
|
||||
|
|
@ -395,7 +389,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 group-hover:!flex' />
|
||||
<div className='mx-1 !hidden h-[14px] w-[1px] shrink-0 bg-divider-regular group-hover:!flex' />
|
||||
<div className='!hidden shrink-0 group-hover:!flex'>
|
||||
<CustomPopover
|
||||
htmlContent={<Operations />}
|
||||
|
|
|
|||
|
|
@ -284,9 +284,9 @@ const Chat: FC<ChatProps> = ({
|
|||
{
|
||||
!noStopResponding && isResponding && (
|
||||
<div className='mb-2 flex justify-center'>
|
||||
<Button onClick={onStopResponding}>
|
||||
<StopCircle className='mr-[5px] h-3.5 w-3.5 text-gray-500' />
|
||||
<span className='text-xs font-normal text-gray-500'>{t('appDebug.operation.stopResponding')}</span>
|
||||
<Button className='border-components-panel-border bg-components-panel-bg text-components-button-secondary-text' onClick={onStopResponding}>
|
||||
<StopCircle className='mr-[5px] h-3.5 w-3.5' />
|
||||
<span className='text-xs font-normal'>{t('appDebug.operation.stopResponding')}</span>
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,13 @@ const WorkflowVariableBlockReplacementBlock = ({
|
|||
variables,
|
||||
}: WorkflowVariableBlockType) => {
|
||||
const [editor] = useLexicalComposerContext()
|
||||
const ragVariables = variables?.reduce<any[]>((acc, curr) => {
|
||||
if (curr.nodeId === 'rag')
|
||||
acc.push(...curr.vars)
|
||||
else
|
||||
acc.push(...curr.vars.filter(v => v.isRagVariable))
|
||||
return acc
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (!editor.hasNodes([WorkflowVariableBlockNode]))
|
||||
|
|
@ -32,7 +39,7 @@ const WorkflowVariableBlockReplacementBlock = ({
|
|||
onInsert()
|
||||
|
||||
const nodePathString = textNode.getTextContent().slice(3, -3)
|
||||
return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType, variables?.find(o => o.nodeId === 'env')?.vars || [], variables?.find(o => o.nodeId === 'conversation')?.vars || [], variables?.find(o => o.nodeId === 'rag')?.vars || []))
|
||||
return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType, variables?.find(o => o.nodeId === 'env')?.vars || [], variables?.find(o => o.nodeId === 'conversation')?.vars || [], ragVariables))
|
||||
}, [onInsert, workflowNodesMap, getVarType, variables])
|
||||
|
||||
const getMatch = useCallback((text: string) => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
import cn from '@/utils/classnames'
|
||||
import React, { useMemo } from 'react'
|
||||
|
||||
type CredentialIconProps = {
|
||||
avatar_url?: string
|
||||
name: string
|
||||
size?: number
|
||||
className?: string
|
||||
}
|
||||
|
||||
const ICON_BG_COLORS = [
|
||||
'bg-components-icon-bg-orange-dark-solid',
|
||||
'bg-components-icon-bg-pink-solid',
|
||||
'bg-components-icon-bg-indigo-solid',
|
||||
'bg-components-icon-bg-teal-solid',
|
||||
]
|
||||
|
||||
export const CredentialIcon: React.FC<CredentialIconProps> = ({
|
||||
avatar_url,
|
||||
name,
|
||||
size = 20,
|
||||
className = '',
|
||||
}) => {
|
||||
const firstLetter = useMemo(() => name.charAt(0).toUpperCase(), [name])
|
||||
const bgColor = useMemo(() => ICON_BG_COLORS[firstLetter.charCodeAt(0) % ICON_BG_COLORS.length], [firstLetter])
|
||||
|
||||
if (avatar_url && avatar_url !== 'default') {
|
||||
return (
|
||||
<img
|
||||
src={avatar_url}
|
||||
alt={`${name} logo`}
|
||||
width={size}
|
||||
height={size}
|
||||
className={cn('shrink-0 rounded-md border border-divider-regular object-contain', className)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex shrink-0 items-center justify-center rounded-md border border-divider-regular',
|
||||
bgColor,
|
||||
className,
|
||||
)}
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
>
|
||||
<span className='bg-gradient-to-b from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text text-[13px] font-semibold leading-[1.2] text-transparent opacity-90'>
|
||||
{firstLetter}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
@ -313,7 +313,7 @@ const FileUploader = ({
|
|||
<RiUploadCloud2Line className='mr-2 size-5' />
|
||||
|
||||
<span>
|
||||
{t('datasetCreation.stepOne.uploader.button')}
|
||||
{notSupportBatchUpload ? t('datasetCreation.stepOne.uploader.buttonSingleFile') : t('datasetCreation.stepOne.uploader.button')}
|
||||
{supportTypes.length > 0 && (
|
||||
<label className="ml-1 cursor-pointer text-text-accent" onClick={selectHandle}>{t('datasetCreation.stepOne.uploader.browse')}</label>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import React, { useCallback } from 'react'
|
||||
import React, { useCallback, useEffect, useMemo } from 'react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
|
|
@ -24,7 +24,14 @@ const CredentialSelector = ({
|
|||
}: CredentialSelectorProps) => {
|
||||
const [open, { toggle }] = useBoolean(false)
|
||||
|
||||
const currentCredential = credentials.find(cred => cred.id === currentCredentialId) as DataSourceCredential
|
||||
const currentCredential = useMemo(() => {
|
||||
return credentials.find(cred => cred.id === currentCredentialId)
|
||||
}, [credentials, currentCredentialId])
|
||||
|
||||
useEffect(() => {
|
||||
if (!currentCredential && credentials.length)
|
||||
onCredentialChange(credentials[0].id)
|
||||
}, [currentCredential, credentials])
|
||||
|
||||
const handleCredentialChange = useCallback((credentialId: string) => {
|
||||
onCredentialChange(credentialId)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { CredentialIcon } from '@/app/components/datasets/common/credential-icon'
|
||||
import type { DataSourceCredential } from '@/types/pipeline'
|
||||
import { RiCheckLine } from '@remixicon/react'
|
||||
import React, { useCallback } from 'react'
|
||||
|
|
@ -28,8 +29,12 @@ const Item = ({
|
|||
className='flex cursor-pointer items-center gap-x-2 rounded-lg p-2 hover:bg-state-base-hover'
|
||||
onClick={handleCredentialChange}
|
||||
>
|
||||
<img src={avatar_url} className='size-5 shrink-0 rounded-md border border-divider-regular object-contain' />
|
||||
<span className='system-sm-medium grow text-text-secondary'>
|
||||
<CredentialIcon
|
||||
avatar_url={avatar_url}
|
||||
name={name}
|
||||
size={20}
|
||||
/>
|
||||
<span className='system-sm-medium grow truncate text-text-secondary'>
|
||||
{t('datasetPipeline.credentialSelector.name', {
|
||||
credentialName: name,
|
||||
pluginName,
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ import type { DataSourceCredential } from '@/types/pipeline'
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import { CredentialIcon } from '@/app/components/datasets/common/credential-icon'
|
||||
|
||||
type TriggerProps = {
|
||||
currentCredential: DataSourceCredential
|
||||
currentCredential: DataSourceCredential | undefined
|
||||
pluginName: string
|
||||
isOpen: boolean
|
||||
}
|
||||
|
|
@ -19,23 +20,29 @@ const Trigger = ({
|
|||
|
||||
const {
|
||||
avatar_url,
|
||||
name,
|
||||
} = currentCredential
|
||||
name = '',
|
||||
} = currentCredential || {}
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
'flex cursor-pointer items-center gap-x-2 rounded-md p-1 pr-2',
|
||||
isOpen ? 'bg-state-base-hover' : 'hover:bg-state-base-hover',
|
||||
)}>
|
||||
<img src={avatar_url} className='size-5 shrink-0 rounded-md border border-divider-regular object-contain' />
|
||||
<div className='flex grow items-center gap-x-1'>
|
||||
<span className='system-md-semibold text-text-secondary'>
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center gap-x-2 rounded-md p-1 pr-2',
|
||||
isOpen ? 'bg-state-base-hover' : 'hover:bg-state-base-hover',
|
||||
)}
|
||||
>
|
||||
<CredentialIcon
|
||||
avatar_url={avatar_url}
|
||||
name={name}
|
||||
size={20}
|
||||
/>
|
||||
<div className='flex items-center gap-x-1'>
|
||||
<span className='system-md-semibold min-w-0 truncate text-text-secondary'>
|
||||
{t('datasetPipeline.credentialSelector.name', {
|
||||
credentialName: name,
|
||||
pluginName,
|
||||
})}
|
||||
</span>
|
||||
<RiArrowDownSLine className='size-4 text-text-secondary' />
|
||||
<RiArrowDownSLine className='size-4 shrink-0 text-text-secondary' />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ const Header = ({
|
|||
|
||||
return (
|
||||
<div className='flex items-center gap-x-2'>
|
||||
<div className='flex shrink-0 grow items-center gap-x-1'>
|
||||
<div className='flex grow items-center gap-x-1'>
|
||||
<CredentialSelector
|
||||
{...rest}
|
||||
/>
|
||||
<Divider type='vertical' className='mx-1 h-3.5' />
|
||||
<Divider type='vertical' className='mx-1 h-3.5 shrink-0' />
|
||||
<Tooltip
|
||||
popupContent={t('datasetPipeline.configurationTip', { pluginName: rest.pluginName })}
|
||||
position='top'
|
||||
|
|
@ -35,7 +35,7 @@ const Header = ({
|
|||
<Button
|
||||
variant='ghost'
|
||||
size='small'
|
||||
className='size-6 px-1'
|
||||
className='size-6 shrink-0 px-1'
|
||||
>
|
||||
<RiEqualizer2Line
|
||||
className='h-4 w-4'
|
||||
|
|
@ -45,13 +45,13 @@ const Header = ({
|
|||
</Tooltip>
|
||||
</div>
|
||||
<a
|
||||
className='system-xs-medium flex items-center gap-x-1 overflow-hidden text-text-accent'
|
||||
className='system-xs-medium flex shrink-0 items-center gap-x-1 text-text-accent'
|
||||
href={docLink}
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
>
|
||||
<RiBookOpenLine className='size-3.5 shrink-0' />
|
||||
<span className='grow truncate' title={docTitle}>{docTitle}</span>
|
||||
<span title={docTitle}>{docTitle}</span>
|
||||
</a>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,19 +13,21 @@ import { useDataSourceStore, useDataSourceStoreWithSelector } from '../store'
|
|||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import Title from './title'
|
||||
import { CredentialTypeEnum } from '@/app/components/plugins/plugin-auth'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetDataSourceAuth } from '@/service/use-datasource'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
|
||||
type OnlineDocumentsProps = {
|
||||
isInPipeline?: boolean
|
||||
nodeId: string
|
||||
nodeData: DataSourceNodeType
|
||||
onCredentialChange: (credentialId: string) => void
|
||||
}
|
||||
|
||||
const OnlineDocuments = ({
|
||||
nodeId,
|
||||
nodeData,
|
||||
isInPipeline = false,
|
||||
onCredentialChange,
|
||||
}: OnlineDocumentsProps) => {
|
||||
const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id)
|
||||
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
|
||||
|
|
@ -33,13 +35,19 @@ const OnlineDocuments = ({
|
|||
documentsData,
|
||||
searchValue,
|
||||
selectedPagesId,
|
||||
currentWorkspaceId,
|
||||
currentCredentialId,
|
||||
} = useDataSourceStoreWithSelector(useShallow(state => ({
|
||||
documentsData: state.documentsData,
|
||||
searchValue: state.searchValue,
|
||||
selectedPagesId: state.selectedPagesId,
|
||||
currentWorkspaceId: state.currentWorkspaceId,
|
||||
currentCredentialId: state.currentCredentialId,
|
||||
})))
|
||||
|
||||
const { data: dataSourceAuth } = useGetDataSourceAuth({
|
||||
pluginId: nodeData.plugin_id,
|
||||
provider: nodeData.provider_name,
|
||||
})
|
||||
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const PagesMapAndSelectedPagesId: DataSourceNotionPageMap = useMemo(() => {
|
||||
|
|
@ -61,19 +69,21 @@ const OnlineDocuments = ({
|
|||
: `/rag/pipelines/${pipelineId}/workflows/draft/datasource/nodes/${nodeId}/run`
|
||||
|
||||
const getOnlineDocuments = useCallback(async () => {
|
||||
const { currentCredentialId } = dataSourceStore.getState()
|
||||
if (!currentCredentialId) return
|
||||
ssePost(
|
||||
datasourceNodeRunURL,
|
||||
{
|
||||
body: {
|
||||
inputs: {},
|
||||
credential_id: currentCredentialId,
|
||||
datasource_type: DatasourceType.onlineDocument,
|
||||
},
|
||||
},
|
||||
{
|
||||
onDataSourceNodeCompleted: (documentsData: DataSourceNodeCompletedResponse) => {
|
||||
const { setDocumentsData, setCurrentWorkspaceId } = dataSourceStore.getState()
|
||||
const { setDocumentsData } = dataSourceStore.getState()
|
||||
setDocumentsData(documentsData.data as DataSourceNotionWorkspace[])
|
||||
setCurrentWorkspaceId(documentsData.data[0].workspace_id)
|
||||
},
|
||||
onDataSourceNodeError: (error: DataSourceNodeErrorResponse) => {
|
||||
Toast.notify({
|
||||
|
|
@ -86,33 +96,8 @@ const OnlineDocuments = ({
|
|||
}, [dataSourceStore, datasourceNodeRunURL])
|
||||
|
||||
useEffect(() => {
|
||||
const {
|
||||
setDocumentsData,
|
||||
setCurrentWorkspaceId,
|
||||
setSearchValue,
|
||||
setSelectedPagesId,
|
||||
setOnlineDocuments,
|
||||
setCurrentDocument,
|
||||
currentNodeIdRef,
|
||||
} = dataSourceStore.getState()
|
||||
if (nodeId !== currentNodeIdRef.current) {
|
||||
setDocumentsData([])
|
||||
setCurrentWorkspaceId('')
|
||||
setSearchValue('')
|
||||
setSelectedPagesId(new Set())
|
||||
setOnlineDocuments([])
|
||||
setCurrentDocument(undefined)
|
||||
currentNodeIdRef.current = nodeId
|
||||
getOnlineDocuments()
|
||||
}
|
||||
else {
|
||||
// Avoid fetching documents when come back from next step
|
||||
if (!documentsData.length)
|
||||
getOnlineDocuments()
|
||||
}
|
||||
}, [nodeId])
|
||||
|
||||
const currentWorkspace = documentsData.find(workspace => workspace.workspace_id === currentWorkspaceId)
|
||||
getOnlineDocuments()
|
||||
}, [currentCredentialId])
|
||||
|
||||
const handleSearchValueChange = useCallback((value: string) => {
|
||||
const { setSearchValue } = dataSourceStore.getState()
|
||||
|
|
@ -137,29 +122,16 @@ const OnlineDocuments = ({
|
|||
})
|
||||
}, [setShowAccountSettingModal])
|
||||
|
||||
if (!documentsData?.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className='flex flex-col gap-y-2'>
|
||||
<Header
|
||||
// todo: delete mock data
|
||||
docTitle='How to use?'
|
||||
docLink='https://docs.dify.ai'
|
||||
onClickConfiguration={handleSetting}
|
||||
pluginName={nodeData.datasource_label}
|
||||
currentCredentialId={'12345678'}
|
||||
onCredentialChange={noop}
|
||||
credentials={[{
|
||||
avatar_url: 'https://cloud.dify.ai/logo/logo.svg',
|
||||
credential: {
|
||||
credentials: '......',
|
||||
},
|
||||
id: '12345678',
|
||||
is_default: true,
|
||||
name: 'test123',
|
||||
type: CredentialTypeEnum.API_KEY,
|
||||
}]}
|
||||
currentCredentialId={currentCredentialId}
|
||||
onCredentialChange={onCredentialChange}
|
||||
credentials={dataSourceAuth?.result || []}
|
||||
/>
|
||||
<div className='rounded-xl border border-components-panel-border bg-background-default-subtle'>
|
||||
<div className='flex items-center gap-x-2 rounded-t-xl border-b border-b-divider-regular bg-components-panel-bg p-1 pl-3'>
|
||||
|
|
@ -172,18 +144,24 @@ const OnlineDocuments = ({
|
|||
/>
|
||||
</div>
|
||||
<div className='overflow-hidden rounded-b-xl'>
|
||||
<PageSelector
|
||||
checkedIds={selectedPagesId}
|
||||
disabledValue={new Set()}
|
||||
searchValue={searchValue}
|
||||
list={currentWorkspace?.pages || []}
|
||||
pagesMap={PagesMapAndSelectedPagesId}
|
||||
onSelect={handleSelectPages}
|
||||
canPreview={!isInPipeline}
|
||||
onPreview={handlePreviewPage}
|
||||
isMultipleChoice={!isInPipeline}
|
||||
currentWorkspaceId={currentWorkspaceId}
|
||||
/>
|
||||
{documentsData?.length ? (
|
||||
<PageSelector
|
||||
checkedIds={selectedPagesId}
|
||||
disabledValue={new Set()}
|
||||
searchValue={searchValue}
|
||||
list={documentsData[0].pages || []}
|
||||
pagesMap={PagesMapAndSelectedPagesId}
|
||||
onSelect={handleSelectPages}
|
||||
canPreview={!isInPipeline}
|
||||
onPreview={handlePreviewPage}
|
||||
isMultipleChoice={!isInPipeline}
|
||||
currentCredentialId={currentCredentialId}
|
||||
/>
|
||||
) : (
|
||||
<div className='flex h-[296px] items-center justify-center'>
|
||||
<Loading type='app' />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ type PageSelectorProps = {
|
|||
canPreview?: boolean
|
||||
onPreview?: (selectedPageId: string) => void
|
||||
isMultipleChoice?: boolean
|
||||
currentWorkspaceId: string
|
||||
currentCredentialId: string
|
||||
}
|
||||
|
||||
export type NotionPageTreeItem = {
|
||||
|
|
@ -42,7 +42,7 @@ const PageSelector = ({
|
|||
canPreview = true,
|
||||
onPreview,
|
||||
isMultipleChoice = true,
|
||||
currentWorkspaceId,
|
||||
currentCredentialId,
|
||||
}: PageSelectorProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [dataList, setDataList] = useState<NotionPageItem[]>([])
|
||||
|
|
@ -56,8 +56,7 @@ const PageSelector = ({
|
|||
depth: 0,
|
||||
}
|
||||
}))
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentWorkspaceId])
|
||||
}, [currentCredentialId])
|
||||
|
||||
const searchDataList = list.filter((item) => {
|
||||
return item.page_name.includes(searchValue)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import { useDataSourceStore } from '../../../../store'
|
|||
import Bucket from './bucket'
|
||||
import BreadcrumbItem from './item'
|
||||
import Dropdown from './dropdown'
|
||||
import type { OnlineDriveFile } from '@/models/pipeline'
|
||||
|
||||
type BreadcrumbsProps = {
|
||||
prefix: string[]
|
||||
|
|
@ -12,12 +11,6 @@ type BreadcrumbsProps = {
|
|||
bucket: string
|
||||
searchResultsLength: number
|
||||
isInPipeline: boolean
|
||||
getOnlineDriveFiles: (params: {
|
||||
prefix?: string[]
|
||||
bucket?: string
|
||||
startAfter?: string
|
||||
fileList?: OnlineDriveFile[]
|
||||
}) => void
|
||||
}
|
||||
|
||||
const Breadcrumbs = ({
|
||||
|
|
@ -26,7 +19,6 @@ const Breadcrumbs = ({
|
|||
bucket,
|
||||
searchResultsLength,
|
||||
isInPipeline,
|
||||
getOnlineDriveFiles,
|
||||
}: BreadcrumbsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
|
@ -56,23 +48,14 @@ const Breadcrumbs = ({
|
|||
setSelectedFileKeys([])
|
||||
setBucket('')
|
||||
setPrefix([])
|
||||
getOnlineDriveFiles({
|
||||
prefix: [],
|
||||
bucket: '',
|
||||
fileList: [],
|
||||
})
|
||||
}, [dataSourceStore, getOnlineDriveFiles])
|
||||
}, [dataSourceStore])
|
||||
|
||||
const handleClickBucketName = useCallback(() => {
|
||||
const { setFileList, setSelectedFileKeys, setPrefix } = dataSourceStore.getState()
|
||||
setFileList([])
|
||||
setSelectedFileKeys([])
|
||||
setPrefix([])
|
||||
getOnlineDriveFiles({
|
||||
prefix: [],
|
||||
fileList: [],
|
||||
})
|
||||
}, [dataSourceStore, getOnlineDriveFiles])
|
||||
}, [dataSourceStore])
|
||||
|
||||
const handleClickBreadcrumb = useCallback((index: number) => {
|
||||
const { prefix, setFileList, setSelectedFileKeys, setPrefix } = dataSourceStore.getState()
|
||||
|
|
@ -80,11 +63,7 @@ const Breadcrumbs = ({
|
|||
setFileList([])
|
||||
setSelectedFileKeys([])
|
||||
setPrefix(newPrefix)
|
||||
getOnlineDriveFiles({
|
||||
prefix: newPrefix,
|
||||
fileList: [],
|
||||
})
|
||||
}, [dataSourceStore, getOnlineDriveFiles])
|
||||
}, [dataSourceStore])
|
||||
|
||||
return (
|
||||
<div className='flex grow items-center overflow-hidden'>
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import React from 'react'
|
|||
import Breadcrumbs from './breadcrumbs'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { OnlineDriveFile } from '@/models/pipeline'
|
||||
|
||||
type HeaderProps = {
|
||||
prefix: string[]
|
||||
|
|
@ -13,12 +12,6 @@ type HeaderProps = {
|
|||
handleInputChange: React.ChangeEventHandler<HTMLInputElement>
|
||||
handleResetKeywords: () => void
|
||||
isInPipeline: boolean
|
||||
getOnlineDriveFiles: (params: {
|
||||
prefix?: string[]
|
||||
bucket?: string
|
||||
startAfter?: string
|
||||
fileList?: OnlineDriveFile[]
|
||||
}) => void
|
||||
}
|
||||
|
||||
const Header = ({
|
||||
|
|
@ -30,7 +23,6 @@ const Header = ({
|
|||
searchResultsLength,
|
||||
handleInputChange,
|
||||
handleResetKeywords,
|
||||
getOnlineDriveFiles,
|
||||
}: HeaderProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
|
|
@ -42,7 +34,6 @@ const Header = ({
|
|||
bucket={bucket}
|
||||
searchResultsLength={searchResultsLength}
|
||||
isInPipeline={isInPipeline}
|
||||
getOnlineDriveFiles={getOnlineDriveFiles}
|
||||
/>
|
||||
<Input
|
||||
value={inputValue}
|
||||
|
|
|
|||
|
|
@ -17,12 +17,6 @@ type FileListProps = {
|
|||
handleSelectFile: (file: OnlineDriveFile) => void
|
||||
handleOpenFolder: (file: OnlineDriveFile) => void
|
||||
isLoading: boolean
|
||||
getOnlineDriveFiles: (params: {
|
||||
prefix?: string[]
|
||||
bucket?: string
|
||||
startAfter?: string
|
||||
fileList?: OnlineDriveFile[]
|
||||
}) => void
|
||||
}
|
||||
|
||||
const FileList = ({
|
||||
|
|
@ -38,7 +32,6 @@ const FileList = ({
|
|||
handleOpenFolder,
|
||||
isInPipeline,
|
||||
isLoading,
|
||||
getOnlineDriveFiles,
|
||||
}: FileListProps) => {
|
||||
const [inputValue, setInputValue] = useState(keywords)
|
||||
|
||||
|
|
@ -71,7 +64,6 @@ const FileList = ({
|
|||
handleInputChange={handleInputChange}
|
||||
searchResultsLength={searchResultsLength}
|
||||
handleResetKeywords={handleResetKeywords}
|
||||
getOnlineDriveFiles={getOnlineDriveFiles}
|
||||
/>
|
||||
<List
|
||||
fileList={fileList}
|
||||
|
|
@ -82,7 +74,6 @@ const FileList = ({
|
|||
handleSelectFile={handleSelectFile}
|
||||
isInPipeline={isInPipeline}
|
||||
isLoading={isLoading}
|
||||
getOnlineDriveFiles={getOnlineDriveFiles}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -16,12 +16,6 @@ type FileListProps = {
|
|||
handleResetKeywords: () => void
|
||||
handleSelectFile: (file: OnlineDriveFile) => void
|
||||
handleOpenFolder: (file: OnlineDriveFile) => void
|
||||
getOnlineDriveFiles: (params: {
|
||||
prefix?: string[]
|
||||
bucket?: string
|
||||
startAfter?: string
|
||||
fileList?: OnlineDriveFile[]
|
||||
}) => void
|
||||
}
|
||||
|
||||
const List = ({
|
||||
|
|
@ -33,27 +27,23 @@ const List = ({
|
|||
handleOpenFolder,
|
||||
isInPipeline,
|
||||
isLoading,
|
||||
getOnlineDriveFiles,
|
||||
}: FileListProps) => {
|
||||
const anchorRef = useRef<HTMLDivElement>(null)
|
||||
const observerRef = useRef<IntersectionObserver>()
|
||||
const observerRef = useRef<IntersectionObserver>(null)
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
useEffect(() => {
|
||||
if (anchorRef.current) {
|
||||
observerRef.current = new IntersectionObserver((entries) => {
|
||||
const { startAfter, isTruncated } = dataSourceStore.getState()
|
||||
if (entries[0].isIntersecting && isTruncated.current && !isLoading) {
|
||||
startAfter.current = fileList[fileList.length - 1].key
|
||||
getOnlineDriveFiles({ startAfter: fileList[fileList.length - 1].key })
|
||||
}
|
||||
const { setStartAfter, isTruncated } = dataSourceStore.getState()
|
||||
if (entries[0].isIntersecting && isTruncated.current && !isLoading)
|
||||
setStartAfter(fileList[fileList.length - 1].key)
|
||||
}, {
|
||||
rootMargin: '100px',
|
||||
})
|
||||
observerRef.current.observe(anchorRef.current)
|
||||
}
|
||||
return () => observerRef.current?.disconnect()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [anchorRef])
|
||||
|
||||
const isAllLoading = isLoading && fileList.length === 0 && keywords.length === 0
|
||||
|
|
|
|||
|
|
@ -13,54 +13,55 @@ import { convertOnlineDriveData } from './utils'
|
|||
import produce from 'immer'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { noop } from 'lodash-es'
|
||||
import { CredentialTypeEnum } from '@/app/components/plugins/plugin-auth'
|
||||
import { useGetDataSourceAuth } from '@/service/use-datasource'
|
||||
|
||||
type OnlineDriveProps = {
|
||||
nodeId: string
|
||||
nodeData: DataSourceNodeType
|
||||
isInPipeline?: boolean
|
||||
onCredentialChange: (credentialId: string) => void
|
||||
}
|
||||
|
||||
const OnlineDrive = ({
|
||||
nodeId,
|
||||
nodeData,
|
||||
isInPipeline = false,
|
||||
onCredentialChange,
|
||||
}: OnlineDriveProps) => {
|
||||
const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id)
|
||||
const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal)
|
||||
const {
|
||||
startAfter,
|
||||
prefix,
|
||||
keywords,
|
||||
bucket,
|
||||
selectedFileKeys,
|
||||
fileList,
|
||||
currentCredentialId,
|
||||
} = useDataSourceStoreWithSelector(useShallow(state => ({
|
||||
startAfter: state.startAfter,
|
||||
prefix: state.prefix,
|
||||
keywords: state.keywords,
|
||||
bucket: state.bucket,
|
||||
selectedFileKeys: state.selectedFileKeys,
|
||||
fileList: state.fileList,
|
||||
currentCredentialId: state.currentCredentialId,
|
||||
})))
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const { data: dataSourceAuth } = useGetDataSourceAuth({
|
||||
pluginId: nodeData.plugin_id,
|
||||
provider: nodeData.provider_name,
|
||||
})
|
||||
|
||||
const datasourceNodeRunURL = !isInPipeline
|
||||
? `/rag/pipelines/${pipelineId}/workflows/published/datasource/nodes/${nodeId}/run`
|
||||
: `/rag/pipelines/${pipelineId}/workflows/draft/datasource/nodes/${nodeId}/run`
|
||||
|
||||
const getOnlineDriveFiles = useCallback(async (params: {
|
||||
prefix?: string[]
|
||||
bucket?: string
|
||||
startAfter?: string
|
||||
fileList?: OnlineDriveFile[]
|
||||
}) => {
|
||||
const { startAfter, prefix, bucket, fileList } = dataSourceStore.getState()
|
||||
const _prefix = params.prefix ?? prefix
|
||||
const _bucket = params.bucket ?? bucket
|
||||
const _startAfter = params.startAfter ?? startAfter.current
|
||||
const _fileList = params.fileList ?? fileList
|
||||
const prefixString = _prefix.length > 0 ? `${_prefix.join('/')}/` : ''
|
||||
const getOnlineDriveFiles = useCallback(async () => {
|
||||
const { startAfter, prefix, bucket, fileList, currentCredentialId } = dataSourceStore.getState()
|
||||
const prefixString = prefix.length > 0 ? `${prefix.join('/')}/` : ''
|
||||
setIsLoading(true)
|
||||
ssePost(
|
||||
datasourceNodeRunURL,
|
||||
|
|
@ -68,18 +69,19 @@ const OnlineDrive = ({
|
|||
body: {
|
||||
inputs: {
|
||||
prefix: prefixString,
|
||||
bucket: _bucket,
|
||||
start_after: _startAfter,
|
||||
bucket,
|
||||
start_after: startAfter,
|
||||
max_keys: 30, // Adjust as needed
|
||||
},
|
||||
datasource_type: DatasourceType.onlineDrive,
|
||||
credential_id: currentCredentialId,
|
||||
},
|
||||
},
|
||||
{
|
||||
onDataSourceNodeCompleted: (documentsData: DataSourceNodeCompletedResponse) => {
|
||||
const { setFileList, isTruncated } = dataSourceStore.getState()
|
||||
const { fileList: newFileList, isTruncated: newIsTruncated } = convertOnlineDriveData(documentsData.data, _prefix, _bucket)
|
||||
setFileList([..._fileList, ...newFileList])
|
||||
const { fileList: newFileList, isTruncated: newIsTruncated } = convertOnlineDriveData(documentsData.data, prefix, bucket)
|
||||
setFileList([...fileList, ...newFileList])
|
||||
isTruncated.current = newIsTruncated
|
||||
setIsLoading(false)
|
||||
},
|
||||
|
|
@ -95,34 +97,8 @@ const OnlineDrive = ({
|
|||
}, [datasourceNodeRunURL, dataSourceStore])
|
||||
|
||||
useEffect(() => {
|
||||
const {
|
||||
setFileList,
|
||||
setBucket,
|
||||
setPrefix,
|
||||
setKeywords,
|
||||
setSelectedFileKeys,
|
||||
currentNodeIdRef,
|
||||
} = dataSourceStore.getState()
|
||||
if (nodeId !== currentNodeIdRef.current) {
|
||||
setFileList([])
|
||||
setBucket('')
|
||||
setPrefix([])
|
||||
setKeywords('')
|
||||
setSelectedFileKeys([])
|
||||
currentNodeIdRef.current = nodeId
|
||||
getOnlineDriveFiles({
|
||||
prefix: [],
|
||||
bucket: '',
|
||||
fileList: [],
|
||||
startAfter: '',
|
||||
})
|
||||
}
|
||||
else {
|
||||
// Avoid fetching files when come back from next step
|
||||
if (fileList.length > 0) return
|
||||
getOnlineDriveFiles({})
|
||||
}
|
||||
}, [nodeId])
|
||||
getOnlineDriveFiles()
|
||||
}, [startAfter, prefix, bucket, currentCredentialId])
|
||||
|
||||
const onlineDriveFileList = useMemo(() => {
|
||||
if (keywords)
|
||||
|
|
@ -163,7 +139,6 @@ const OnlineDrive = ({
|
|||
setFileList([])
|
||||
if (file.type === OnlineDriveFileType.bucket) {
|
||||
setBucket(file.displayName)
|
||||
getOnlineDriveFiles({ bucket: file.displayName, fileList: [] })
|
||||
}
|
||||
else {
|
||||
setSelectedFileKeys([])
|
||||
|
|
@ -172,7 +147,6 @@ const OnlineDrive = ({
|
|||
draft.push(displayName)
|
||||
})
|
||||
setPrefix(newPrefix)
|
||||
getOnlineDriveFiles({ prefix: newPrefix, fileList: [] })
|
||||
}
|
||||
}, [dataSourceStore, getOnlineDriveFiles])
|
||||
|
||||
|
|
@ -185,23 +159,13 @@ const OnlineDrive = ({
|
|||
return (
|
||||
<div className='flex flex-col gap-y-2'>
|
||||
<Header
|
||||
// todo: delete mock data
|
||||
docTitle='Online Drive Docs'
|
||||
docLink='https://docs.dify.ai/'
|
||||
onClickConfiguration={handleSetting}
|
||||
pluginName={nodeData.datasource_label}
|
||||
currentCredentialId={'12345678'}
|
||||
onCredentialChange={noop}
|
||||
credentials={[{
|
||||
avatar_url: 'https://cloud.dify.ai/logo/logo.svg',
|
||||
credential: {
|
||||
credentials: '......',
|
||||
},
|
||||
id: '12345678',
|
||||
is_default: true,
|
||||
name: 'test123',
|
||||
type: CredentialTypeEnum.API_KEY,
|
||||
}]}
|
||||
currentCredentialId={currentCredentialId}
|
||||
onCredentialChange={onCredentialChange}
|
||||
credentials={dataSourceAuth?.result || []}
|
||||
/>
|
||||
<FileList
|
||||
fileList={onlineDriveFileList}
|
||||
|
|
@ -216,7 +180,6 @@ const OnlineDrive = ({
|
|||
handleOpenFolder={handleOpenFolder}
|
||||
isInPipeline={isInPipeline}
|
||||
isLoading={isLoading}
|
||||
getOnlineDriveFiles={getOnlineDriveFiles}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,12 +12,11 @@ import { createWebsiteCrawlSlice } from './slices/website-crawl'
|
|||
import type { OnlineDriveSliceShape } from './slices/online-drive'
|
||||
import { createOnlineDriveSlice } from './slices/online-drive'
|
||||
|
||||
export type DataSourceShape =
|
||||
CommonShape &
|
||||
LocalFileSliceShape &
|
||||
OnlineDocumentSliceShape &
|
||||
WebsiteCrawlSliceShape &
|
||||
OnlineDriveSliceShape
|
||||
export type DataSourceShape = CommonShape
|
||||
& LocalFileSliceShape
|
||||
& OnlineDocumentSliceShape
|
||||
& WebsiteCrawlSliceShape
|
||||
& OnlineDriveSliceShape
|
||||
|
||||
export const createDataSourceStore = () => {
|
||||
return createStore<DataSourceShape>((...args) => ({
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ type DataSourceProviderProps = {
|
|||
const DataSourceProvider = ({
|
||||
children,
|
||||
}: DataSourceProviderProps) => {
|
||||
const storeRef = useRef<DataSourceStoreApi>()
|
||||
const storeRef = useRef<DataSourceStoreApi>(null)
|
||||
|
||||
if (!storeRef.current)
|
||||
storeRef.current = createDataSourceStore()
|
||||
|
|
|
|||
|
|
@ -1,11 +1,19 @@
|
|||
import type { StateCreator } from 'zustand'
|
||||
|
||||
export type CommonShape = {
|
||||
currentNodeIdRef: React.MutableRefObject<string | undefined>
|
||||
currentNodeIdRef: React.RefObject<string>
|
||||
currentCredentialId: string
|
||||
setCurrentCredentialId: (credentialId: string) => void
|
||||
currentCredentialIdRef: React.RefObject<string>
|
||||
}
|
||||
|
||||
export const createCommonSlice: StateCreator<CommonShape> = () => {
|
||||
export const createCommonSlice: StateCreator<CommonShape> = (set) => {
|
||||
return ({
|
||||
currentNodeIdRef: { current: undefined },
|
||||
currentNodeIdRef: { current: '' },
|
||||
currentCredentialId: '',
|
||||
setCurrentCredentialId: (credentialId: string) => {
|
||||
set({ currentCredentialId: credentialId })
|
||||
},
|
||||
currentCredentialIdRef: { current: '' },
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ export type LocalFileSliceShape = {
|
|||
setLocalFileList: (fileList: FileItem[]) => void
|
||||
currentLocalFile: File | undefined
|
||||
setCurrentLocalFile: (file: File | undefined) => void
|
||||
previewLocalFileRef: React.MutableRefObject<DocumentItem | undefined>
|
||||
previewLocalFileRef: React.RefObject<DocumentItem | undefined>
|
||||
}
|
||||
|
||||
export const createLocalFileSlice: StateCreator<LocalFileSliceShape> = (set, get) => {
|
||||
|
|
|
|||
|
|
@ -6,15 +6,13 @@ export type OnlineDocumentSliceShape = {
|
|||
setDocumentsData: (documentData: DataSourceNotionWorkspace[]) => void
|
||||
searchValue: string
|
||||
setSearchValue: (searchValue: string) => void
|
||||
currentWorkspaceId: string
|
||||
setCurrentWorkspaceId: (workspaceId: string) => void
|
||||
onlineDocuments: NotionPage[]
|
||||
setOnlineDocuments: (documents: NotionPage[]) => void
|
||||
currentDocument: NotionPage | undefined
|
||||
setCurrentDocument: (document: NotionPage | undefined) => void
|
||||
selectedPagesId: Set<string>
|
||||
setSelectedPagesId: (selectedPagesId: Set<string>) => void
|
||||
previewOnlineDocumentRef: React.MutableRefObject<NotionPage | undefined>
|
||||
previewOnlineDocumentRef: React.RefObject<NotionPage | undefined>
|
||||
}
|
||||
|
||||
export const createOnlineDocumentSlice: StateCreator<OnlineDocumentSliceShape> = (set, get) => {
|
||||
|
|
@ -27,10 +25,6 @@ export const createOnlineDocumentSlice: StateCreator<OnlineDocumentSliceShape> =
|
|||
setSearchValue: (searchValue: string) => set(() => ({
|
||||
searchValue,
|
||||
})),
|
||||
currentWorkspaceId: '',
|
||||
setCurrentWorkspaceId: (workspaceId: string) => set(() => ({
|
||||
currentWorkspaceId: workspaceId,
|
||||
})),
|
||||
onlineDocuments: [],
|
||||
setOnlineDocuments: (documents: NotionPage[]) => {
|
||||
set(() => ({
|
||||
|
|
|
|||
|
|
@ -12,9 +12,10 @@ export type OnlineDriveSliceShape = {
|
|||
setFileList: (fileList: OnlineDriveFile[]) => void
|
||||
bucket: string
|
||||
setBucket: (bucket: string) => void
|
||||
startAfter: React.MutableRefObject<string>
|
||||
isTruncated: React.MutableRefObject<boolean>
|
||||
previewOnlineDriveFileRef: React.MutableRefObject<OnlineDriveFile | undefined>
|
||||
startAfter: string
|
||||
setStartAfter: (startAfter: string) => void
|
||||
isTruncated: React.RefObject<boolean>
|
||||
previewOnlineDriveFileRef: React.RefObject<OnlineDriveFile | undefined>
|
||||
}
|
||||
|
||||
export const createOnlineDriveSlice: StateCreator<OnlineDriveSliceShape> = (set, get) => {
|
||||
|
|
@ -44,7 +45,10 @@ export const createOnlineDriveSlice: StateCreator<OnlineDriveSliceShape> = (set,
|
|||
setBucket: (bucket: string) => set(() => ({
|
||||
bucket,
|
||||
})),
|
||||
startAfter: { current: '' },
|
||||
startAfter: '',
|
||||
setStartAfter: (startAfter: string) => set(() => ({
|
||||
startAfter,
|
||||
})),
|
||||
isTruncated: { current: false },
|
||||
previewOnlineDriveFileRef: { current: undefined },
|
||||
})
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ export type WebsiteCrawlSliceShape = {
|
|||
setStep: (step: CrawlStep) => void
|
||||
previewIndex: number
|
||||
setPreviewIndex: (index: number) => void
|
||||
previewWebsitePageRef: React.MutableRefObject<CrawlResultItem | undefined>
|
||||
previewWebsitePageRef: React.RefObject<CrawlResultItem | undefined>
|
||||
}
|
||||
|
||||
export const createWebsiteCrawlSlice: StateCreator<WebsiteCrawlSliceShape> = (set, get) => {
|
||||
|
|
|
|||
|
|
@ -11,22 +11,21 @@ import Toast from '@/app/components/base/toast'
|
|||
import type { RAGPipelineVariables } from '@/models/pipeline'
|
||||
import { useConfigurations, useInitialData } from '@/app/components/rag-pipeline/hooks/use-input-fields'
|
||||
import { generateZodSchema } from '@/app/components/base/form/form-scenarios/base/utils'
|
||||
import { CrawlStep } from '@/models/datasets'
|
||||
|
||||
const I18N_PREFIX = 'datasetCreation.stepOne.website'
|
||||
|
||||
type OptionsProps = {
|
||||
variables: RAGPipelineVariables
|
||||
isRunning: boolean
|
||||
step: CrawlStep
|
||||
runDisabled?: boolean
|
||||
controlFoldOptions?: number
|
||||
onSubmit: (data: Record<string, any>) => void
|
||||
}
|
||||
|
||||
const Options = ({
|
||||
variables,
|
||||
isRunning,
|
||||
step,
|
||||
runDisabled,
|
||||
controlFoldOptions,
|
||||
onSubmit,
|
||||
}: OptionsProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -62,12 +61,18 @@ const Options = ({
|
|||
const [fold, {
|
||||
toggle: foldToggle,
|
||||
setTrue: foldHide,
|
||||
setFalse: foldShow,
|
||||
}] = useBoolean(false)
|
||||
|
||||
useEffect(() => {
|
||||
if (controlFoldOptions !== 0)
|
||||
// When the step change
|
||||
if (step !== CrawlStep.init)
|
||||
foldHide()
|
||||
}, [controlFoldOptions])
|
||||
else
|
||||
foldShow()
|
||||
}, [step])
|
||||
|
||||
const isRunning = useMemo(() => step === CrawlStep.running, [step])
|
||||
|
||||
return (
|
||||
<form
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
'use client'
|
||||
import React, { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import React, { useCallback, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { CrawlResultItem } from '@/models/datasets'
|
||||
import { CrawlStep } from '@/models/datasets'
|
||||
|
|
@ -24,8 +24,7 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-so
|
|||
import { useDataSourceStore, useDataSourceStoreWithSelector } from '../store'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { CredentialTypeEnum } from '@/app/components/plugins/plugin-auth'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetDataSourceAuth } from '@/service/use-datasource'
|
||||
|
||||
const I18N_PREFIX = 'datasetCreation.stepOne.website'
|
||||
|
||||
|
|
@ -33,15 +32,16 @@ export type WebsiteCrawlProps = {
|
|||
nodeId: string
|
||||
nodeData: DataSourceNodeType
|
||||
isInPipeline?: boolean
|
||||
onCredentialChange: (credentialId: string) => void
|
||||
}
|
||||
|
||||
const WebsiteCrawl = ({
|
||||
nodeId,
|
||||
nodeData,
|
||||
isInPipeline = false,
|
||||
onCredentialChange,
|
||||
}: WebsiteCrawlProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [controlFoldOptions, setControlFoldOptions] = useState<number>(0)
|
||||
const [totalNum, setTotalNum] = useState(0)
|
||||
const [crawledNum, setCrawledNum] = useState(0)
|
||||
const [crawlErrorMessage, setCrawlErrorMessage] = useState('')
|
||||
|
|
@ -52,12 +52,20 @@ const WebsiteCrawl = ({
|
|||
step,
|
||||
checkedCrawlResult,
|
||||
previewIndex,
|
||||
currentCredentialId,
|
||||
} = useDataSourceStoreWithSelector(useShallow(state => ({
|
||||
crawlResult: state.crawlResult,
|
||||
step: state.step,
|
||||
checkedCrawlResult: state.websitePages,
|
||||
previewIndex: state.previewIndex,
|
||||
currentCredentialId: state.currentCredentialId,
|
||||
})))
|
||||
|
||||
const { data: dataSourceAuth } = useGetDataSourceAuth({
|
||||
pluginId: nodeData.plugin_id,
|
||||
provider: nodeData.provider_name,
|
||||
})
|
||||
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const usePreProcessingParams = useRef(!isInPipeline ? usePublishedPipelinePreProcessingParams : useDraftPipelinePreProcessingParams)
|
||||
|
|
@ -66,33 +74,6 @@ const WebsiteCrawl = ({
|
|||
node_id: nodeId,
|
||||
}, !!pipelineId && !!nodeId)
|
||||
|
||||
useEffect(() => {
|
||||
if (step !== CrawlStep.init)
|
||||
setControlFoldOptions(Date.now())
|
||||
}, [step])
|
||||
|
||||
useEffect(() => {
|
||||
const {
|
||||
setStep,
|
||||
setCrawlResult,
|
||||
setWebsitePages,
|
||||
setPreviewIndex,
|
||||
setCurrentWebsite,
|
||||
currentNodeIdRef,
|
||||
} = dataSourceStore.getState()
|
||||
if (nodeId !== currentNodeIdRef.current) {
|
||||
setStep(CrawlStep.init)
|
||||
setCrawlResult(undefined)
|
||||
setCurrentWebsite(undefined)
|
||||
setWebsitePages([])
|
||||
setPreviewIndex(-1)
|
||||
setCrawledNum(0)
|
||||
setTotalNum(0)
|
||||
setCrawlErrorMessage('')
|
||||
currentNodeIdRef.current = nodeId
|
||||
}
|
||||
}, [nodeId])
|
||||
|
||||
const isInit = step === CrawlStep.init
|
||||
const isCrawlFinished = step === CrawlStep.finished
|
||||
const isRunning = step === CrawlStep.running
|
||||
|
|
@ -113,7 +94,7 @@ const WebsiteCrawl = ({
|
|||
}, [dataSourceStore])
|
||||
|
||||
const handleRun = useCallback(async (value: Record<string, any>) => {
|
||||
const { setStep, setCrawlResult } = dataSourceStore.getState()
|
||||
const { setStep, setCrawlResult, currentCredentialId } = dataSourceStore.getState()
|
||||
|
||||
setStep(CrawlStep.running)
|
||||
ssePost(
|
||||
|
|
@ -122,6 +103,7 @@ const WebsiteCrawl = ({
|
|||
body: {
|
||||
inputs: value,
|
||||
datasource_type: DatasourceType.websiteCrawl,
|
||||
credential_id: currentCredentialId,
|
||||
response_mode: 'streaming',
|
||||
},
|
||||
},
|
||||
|
|
@ -165,33 +147,29 @@ const WebsiteCrawl = ({
|
|||
})
|
||||
}, [setShowAccountSettingModal])
|
||||
|
||||
const handleCredentialChange = useCallback((credentialId: string) => {
|
||||
setCrawledNum(0)
|
||||
setTotalNum(0)
|
||||
setCrawlErrorMessage('')
|
||||
onCredentialChange(credentialId)
|
||||
}, [dataSourceStore, onCredentialChange])
|
||||
|
||||
return (
|
||||
<div className='flex flex-col'>
|
||||
<Header
|
||||
// todo: delete mock data
|
||||
docTitle='How to use?'
|
||||
docLink='https://docs.dify.ai'
|
||||
onClickConfiguration={handleSetting}
|
||||
pluginName={nodeData.datasource_label}
|
||||
currentCredentialId={'12345678'}
|
||||
onCredentialChange={noop}
|
||||
credentials={[{
|
||||
avatar_url: 'https://cloud.dify.ai/logo/logo.svg',
|
||||
credential: {
|
||||
credentials: '......',
|
||||
},
|
||||
id: '12345678',
|
||||
is_default: true,
|
||||
name: 'test123',
|
||||
type: CredentialTypeEnum.API_KEY,
|
||||
}]}
|
||||
currentCredentialId={currentCredentialId}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
credentials={dataSourceAuth?.result || []}
|
||||
/>
|
||||
<div className='mt-2 rounded-xl border border-components-panel-border bg-background-default-subtle'>
|
||||
<Options
|
||||
variables={paramsConfig?.variables || []}
|
||||
isRunning={isRunning}
|
||||
runDisabled={isFetchingParams}
|
||||
controlFoldOptions={controlFoldOptions}
|
||||
step={step}
|
||||
runDisabled={!currentCredentialId || isFetchingParams}
|
||||
onSubmit={handleSubmit}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-so
|
|||
import { useDataSourceStore, useDataSourceStoreWithSelector } from './data-source/store'
|
||||
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { CrawlStep } from '@/models/datasets'
|
||||
|
||||
export const useAddDocumentsSteps = () => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -87,21 +88,19 @@ export const useLocalFile = () => {
|
|||
}
|
||||
}
|
||||
|
||||
export const useOnlineDocuments = () => {
|
||||
export const useOnlineDocument = () => {
|
||||
const {
|
||||
documentsData,
|
||||
currentWorkspaceId,
|
||||
onlineDocuments,
|
||||
currentDocument,
|
||||
} = useDataSourceStoreWithSelector(useShallow(state => ({
|
||||
documentsData: state.documentsData,
|
||||
currentWorkspaceId: state.currentWorkspaceId,
|
||||
onlineDocuments: state.onlineDocuments,
|
||||
currentDocument: state.currentDocument,
|
||||
})))
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const currentWorkspace = documentsData.find(workspace => workspace.workspace_id === currentWorkspaceId)
|
||||
const currentWorkspace = documentsData[0]
|
||||
|
||||
const PagesMapAndSelectedPagesId: DataSourceNotionPageMap = useMemo(() => {
|
||||
const pagesMap = (documentsData || []).reduce((prev: DataSourceNotionPageMap, next: DataSourceNotionWorkspace) => {
|
||||
|
|
@ -122,12 +121,28 @@ export const useOnlineDocuments = () => {
|
|||
setCurrentDocument(undefined)
|
||||
}, [dataSourceStore])
|
||||
|
||||
const clearOnlineDocumentData = useCallback(() => {
|
||||
const {
|
||||
setDocumentsData,
|
||||
setSearchValue,
|
||||
setSelectedPagesId,
|
||||
setOnlineDocuments,
|
||||
setCurrentDocument,
|
||||
} = dataSourceStore.getState()
|
||||
setDocumentsData([])
|
||||
setSearchValue('')
|
||||
setSelectedPagesId(new Set())
|
||||
setOnlineDocuments([])
|
||||
setCurrentDocument(undefined)
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
currentWorkspace,
|
||||
onlineDocuments,
|
||||
currentDocument,
|
||||
PagesMapAndSelectedPagesId,
|
||||
hidePreviewOnlineDocument,
|
||||
clearOnlineDocumentData,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -147,10 +162,26 @@ export const useWebsiteCrawl = () => {
|
|||
setPreviewIndex(-1)
|
||||
}, [dataSourceStore])
|
||||
|
||||
const clearWebsiteCrawlData = useCallback(() => {
|
||||
const {
|
||||
setStep,
|
||||
setCrawlResult,
|
||||
setWebsitePages,
|
||||
setPreviewIndex,
|
||||
setCurrentWebsite,
|
||||
} = dataSourceStore.getState()
|
||||
setStep(CrawlStep.init)
|
||||
setCrawlResult(undefined)
|
||||
setCurrentWebsite(undefined)
|
||||
setWebsitePages([])
|
||||
setPreviewIndex(-1)
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
websitePages,
|
||||
currentWebsite,
|
||||
hideWebsitePreview,
|
||||
clearWebsiteCrawlData,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -162,14 +193,31 @@ export const useOnlineDrive = () => {
|
|||
fileList: state.fileList,
|
||||
selectedFileKeys: state.selectedFileKeys,
|
||||
})))
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const selectedOnlineDriveFileList = useMemo(() => {
|
||||
return selectedFileKeys.map(key => fileList.find(item => item.key === key)!)
|
||||
}, [fileList, selectedFileKeys])
|
||||
|
||||
const clearOnlineDriveData = useCallback(() => {
|
||||
const {
|
||||
setFileList,
|
||||
setBucket,
|
||||
setPrefix,
|
||||
setKeywords,
|
||||
setSelectedFileKeys,
|
||||
} = dataSourceStore.getState()
|
||||
setFileList([])
|
||||
setBucket('')
|
||||
setPrefix([])
|
||||
setKeywords('')
|
||||
setSelectedFileKeys([])
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
fileList,
|
||||
selectedFileKeys,
|
||||
selectedOnlineDriveFileList,
|
||||
clearOnlineDriveData,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,10 +24,15 @@ import WebsitePreview from './preview/web-preview'
|
|||
import ProcessDocuments from './process-documents'
|
||||
import ChunkPreview from './preview/chunk-preview'
|
||||
import Processing from './processing'
|
||||
import type { InitialDocumentDetail, OnlineDriveFile, PublishedPipelineRunPreviewResponse, PublishedPipelineRunResponse } from '@/models/pipeline'
|
||||
import type {
|
||||
InitialDocumentDetail,
|
||||
OnlineDriveFile,
|
||||
PublishedPipelineRunPreviewResponse,
|
||||
PublishedPipelineRunResponse,
|
||||
} from '@/models/pipeline'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { useAddDocumentsSteps, useLocalFile, useOnlineDocuments, useOnlineDrive, useWebsiteCrawl } from './hooks'
|
||||
import { useAddDocumentsSteps, useLocalFile, useOnlineDocument, useOnlineDrive, useWebsiteCrawl } from './hooks'
|
||||
import DataSourceProvider from './data-source/store/provider'
|
||||
import { useDataSourceStore } from './data-source/store'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
|
|
@ -67,16 +72,19 @@ const CreateFormPipeline = () => {
|
|||
currentDocument,
|
||||
PagesMapAndSelectedPagesId,
|
||||
hidePreviewOnlineDocument,
|
||||
} = useOnlineDocuments()
|
||||
clearOnlineDocumentData,
|
||||
} = useOnlineDocument()
|
||||
const {
|
||||
websitePages,
|
||||
currentWebsite,
|
||||
hideWebsitePreview,
|
||||
clearWebsiteCrawlData,
|
||||
} = useWebsiteCrawl()
|
||||
const {
|
||||
fileList: onlineDriveFileList,
|
||||
selectedFileKeys,
|
||||
selectedOnlineDriveFileList,
|
||||
clearOnlineDriveData,
|
||||
} = useOnlineDrive()
|
||||
|
||||
const datasourceType = datasource?.nodeData.provider_type
|
||||
|
|
@ -346,6 +354,32 @@ const CreateFormPipeline = () => {
|
|||
}
|
||||
}, [PagesMapAndSelectedPagesId, currentWorkspace?.pages, dataSourceStore, datasourceType])
|
||||
|
||||
const clearDataSourceData = useCallback((dataSource: Datasource) => {
|
||||
if (dataSource.nodeData.provider_type === DatasourceType.onlineDocument)
|
||||
clearOnlineDocumentData()
|
||||
else if (dataSource.nodeData.provider_type === DatasourceType.websiteCrawl)
|
||||
clearWebsiteCrawlData()
|
||||
else if (dataSource.nodeData.provider_type === DatasourceType.onlineDrive)
|
||||
clearOnlineDriveData()
|
||||
}, [])
|
||||
|
||||
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
|
||||
const {
|
||||
setCurrentCredentialId,
|
||||
currentNodeIdRef,
|
||||
} = dataSourceStore.getState()
|
||||
clearDataSourceData(dataSource)
|
||||
setCurrentCredentialId('')
|
||||
currentNodeIdRef.current = dataSource.nodeId
|
||||
setDatasource(dataSource)
|
||||
}, [dataSourceStore])
|
||||
|
||||
const handleCredentialChange = useCallback((credentialId: string) => {
|
||||
const { setCurrentCredentialId } = dataSourceStore.getState()
|
||||
clearDataSourceData(datasource!)
|
||||
setCurrentCredentialId(credentialId)
|
||||
}, [dataSourceStore, datasource])
|
||||
|
||||
if (isFetchingPipelineInfo) {
|
||||
return (
|
||||
<Loading type='app' />
|
||||
|
|
@ -369,7 +403,7 @@ const CreateFormPipeline = () => {
|
|||
<div className='flex flex-col gap-y-5 pt-4'>
|
||||
<DataSourceOptions
|
||||
datasourceNodeId={datasource?.nodeId || ''}
|
||||
onSelect={setDatasource}
|
||||
onSelect={handleSwitchDataSource}
|
||||
pipelineNodes={(pipelineInfo?.graph.nodes || []) as Node<DataSourceNodeType>[]}
|
||||
/>
|
||||
{datasourceType === DatasourceType.localFile && (
|
||||
|
|
@ -382,18 +416,21 @@ const CreateFormPipeline = () => {
|
|||
<OnlineDocuments
|
||||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
{datasourceType === DatasourceType.websiteCrawl && (
|
||||
<WebsiteCrawl
|
||||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
{datasourceType === DatasourceType.onlineDrive && (
|
||||
<OnlineDrive
|
||||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
{isShowVectorSpaceFull && (
|
||||
|
|
|
|||
|
|
@ -88,7 +88,6 @@ const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
|
|||
documentId,
|
||||
params: { metadata: 'without' },
|
||||
})
|
||||
console.log('🚀 ~ DocumentDetail ~ documentDetail:', documentDetail)
|
||||
|
||||
const { data: documentMetadata } = useDocumentMetadata({
|
||||
datasetId,
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ const Datasets = ({
|
|||
})
|
||||
const resetDatasetList = useResetDatasetList()
|
||||
const anchorRef = useRef<HTMLDivElement>(null)
|
||||
const observerRef = useRef<IntersectionObserver>()
|
||||
const observerRef = useRef<IntersectionObserver>(null)
|
||||
|
||||
useEffect(() => {
|
||||
document.title = `${t('dataset.knowledge')} - Dify`
|
||||
|
|
@ -51,7 +51,6 @@ const Datasets = ({
|
|||
observerRef.current.observe(anchorRef.current)
|
||||
}
|
||||
return () => observerRef.current?.disconnect()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [anchorRef])
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
'use client'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiListUnordered } from '@remixicon/react'
|
||||
import { RiCloseLine, RiListUnordered } from '@remixicon/react'
|
||||
import TemplateEn from './template/template.en.mdx'
|
||||
import TemplateZh from './template/template.zh.mdx'
|
||||
import TemplateJa from './template/template.ja.mdx'
|
||||
|
|
@ -30,6 +30,7 @@ const Doc = ({ appDetail }: IDocProps) => {
|
|||
const { t } = useTranslation()
|
||||
const [toc, setToc] = useState<Array<{ href: string; text: string }>>([])
|
||||
const [isTocExpanded, setIsTocExpanded] = useState(false)
|
||||
const [activeSection, setActiveSection] = useState<string>('')
|
||||
const { theme } = useTheme()
|
||||
|
||||
const variables = appDetail?.model_config?.configs?.prompt_variables || []
|
||||
|
|
@ -59,13 +60,43 @@ const Doc = ({ appDetail }: IDocProps) => {
|
|||
return null
|
||||
}).filter((item): item is { href: string; text: string } => item !== null)
|
||||
setToc(tocItems)
|
||||
if (tocItems.length > 0)
|
||||
setActiveSection(tocItems[0].href.replace('#', ''))
|
||||
}
|
||||
}
|
||||
|
||||
// Run after component has rendered
|
||||
setTimeout(extractTOC, 0)
|
||||
}, [appDetail, locale])
|
||||
|
||||
useEffect(() => {
|
||||
const handleScroll = () => {
|
||||
const scrollContainer = document.querySelector('.overflow-auto')
|
||||
if (!scrollContainer || toc.length === 0)
|
||||
return
|
||||
|
||||
let currentSection = ''
|
||||
toc.forEach((item) => {
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const rect = element.getBoundingClientRect()
|
||||
if (rect.top <= window.innerHeight / 2)
|
||||
currentSection = targetId
|
||||
}
|
||||
})
|
||||
|
||||
if (currentSection && currentSection !== activeSection)
|
||||
setActiveSection(currentSection)
|
||||
}
|
||||
|
||||
const scrollContainer = document.querySelector('.overflow-auto')
|
||||
if (scrollContainer) {
|
||||
scrollContainer.addEventListener('scroll', handleScroll)
|
||||
handleScroll()
|
||||
return () => scrollContainer.removeEventListener('scroll', handleScroll)
|
||||
}
|
||||
}, [toc, activeSection])
|
||||
|
||||
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string; text: string }) => {
|
||||
e.preventDefault()
|
||||
const targetId = item.href.replace('#', '')
|
||||
|
|
@ -82,94 +113,128 @@ const Doc = ({ appDetail }: IDocProps) => {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === 'chat' || appDetail?.mode === 'agent-chat') {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === 'advanced-chat') {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateAdvancedChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateAdvancedChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateAdvancedChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === 'workflow') {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateWorkflowZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateWorkflowJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateWorkflowEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === 'completion') {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
return null
|
||||
}, [appDetail, locale, variables, inputs])
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className={`fixed right-8 top-32 z-10 transition-all ${isTocExpanded ? 'w-64' : 'w-10'}`}>
|
||||
<div className={`fixed right-20 top-32 z-10 transition-all duration-150 ease-out ${isTocExpanded ? 'w-[280px]' : 'w-11'}`}>
|
||||
{isTocExpanded
|
||||
? (
|
||||
<nav className="toc max-h-[calc(100vh-150px)] w-full overflow-y-auto rounded-lg border border-components-panel-border bg-components-panel-bg p-4 shadow-md">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<h3 className="text-lg font-semibold text-text-primary">{t('appApi.develop.toc')}</h3>
|
||||
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
|
||||
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
|
||||
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
|
||||
{t('appApi.develop.toc')}
|
||||
</span>
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(false)}
|
||||
className="text-text-tertiary hover:text-text-secondary"
|
||||
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
|
||||
aria-label="Close"
|
||||
>
|
||||
✕
|
||||
<RiCloseLine className="h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
</div>
|
||||
<ul className="space-y-2">
|
||||
{toc.map((item, index) => (
|
||||
<li key={index}>
|
||||
<a
|
||||
href={item.href}
|
||||
className="text-text-secondary transition-colors duration-200 hover:text-text-primary hover:underline"
|
||||
onClick={e => handleTocClick(e, item)}
|
||||
>
|
||||
{item.text}
|
||||
</a>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
|
||||
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
|
||||
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
|
||||
|
||||
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
|
||||
{toc.length === 0 ? (
|
||||
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
|
||||
{t('appApi.develop.noContent')}
|
||||
</div>
|
||||
) : (
|
||||
<ul className="space-y-0.5">
|
||||
{toc.map((item, index) => {
|
||||
const isActive = activeSection === item.href.replace('#', '')
|
||||
return (
|
||||
<li key={index}>
|
||||
<a
|
||||
href={item.href}
|
||||
onClick={e => handleTocClick(e, item)}
|
||||
className={cn(
|
||||
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
|
||||
isActive
|
||||
? 'bg-state-base-hover font-medium text-text-primary'
|
||||
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
|
||||
isActive
|
||||
? 'scale-100 bg-text-accent'
|
||||
: 'scale-75 bg-components-panel-border',
|
||||
)}
|
||||
/>
|
||||
<span className="flex-1 truncate">
|
||||
{item.text}
|
||||
</span>
|
||||
</a>
|
||||
</li>
|
||||
)
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
|
||||
</nav>
|
||||
)
|
||||
: (
|
||||
<button
|
||||
onClick={() => setIsTocExpanded(true)}
|
||||
className="flex h-10 w-10 items-center justify-center rounded-full border border-components-panel-border bg-components-button-secondary-bg shadow-md transition-colors duration-200 hover:bg-components-button-secondary-bg-hover"
|
||||
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
|
||||
aria-label="Open table of contents"
|
||||
>
|
||||
<RiListUnordered className="h-6 w-6 text-components-button-secondary-text" />
|
||||
<RiListUnordered className="h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')} >
|
||||
{(appDetail?.mode === 'chat' || appDetail?.mode === 'agent-chat') && (
|
||||
(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
})()
|
||||
)}
|
||||
{appDetail?.mode === 'advanced-chat' && (
|
||||
(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateAdvancedChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateAdvancedChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateAdvancedChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
})()
|
||||
)}
|
||||
{appDetail?.mode === 'workflow' && (
|
||||
(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateWorkflowZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateWorkflowJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateWorkflowEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
})()
|
||||
)}
|
||||
{appDetail?.mode === 'completion' && (
|
||||
(() => {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
})()
|
||||
)}
|
||||
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')}>
|
||||
{Template}
|
||||
</article>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -448,7 +448,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='テキストから音声'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -423,7 +423,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='文字转语音'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1136,7 +1136,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='Speech to Text'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1187,7 +1187,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='Text to Audio'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1136,7 +1136,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='音声からテキストへ'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1187,7 +1187,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='テキストから音声へ'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1174,7 +1174,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='语音转文字'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1222,7 +1222,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='文字转语音'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1170,7 +1170,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='Speech to Text'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1221,7 +1221,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='Text to Audio'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1169,7 +1169,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='音声からテキストへ'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1220,7 +1220,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='テキストから音声へ'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -1185,7 +1185,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
url='/audio-to-text'
|
||||
method='POST'
|
||||
title='语音转文字'
|
||||
name='#audio'
|
||||
name='#audio-to-text'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
@ -1233,7 +1233,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
url='/text-to-audio'
|
||||
method='POST'
|
||||
title='文字转语音'
|
||||
name='#audio'
|
||||
name='#text-to-audio'
|
||||
/>
|
||||
<Row>
|
||||
<Col>
|
||||
|
|
|
|||
|
|
@ -43,7 +43,10 @@ const Card = ({
|
|||
category: AuthCategory.datasource,
|
||||
provider: `${item.plugin_id}/${item.name}`,
|
||||
}
|
||||
const { handleAuthUpdate } = useDataSourceAuthUpdate()
|
||||
const { handleAuthUpdate } = useDataSourceAuthUpdate({
|
||||
pluginId: item.plugin_id,
|
||||
provider: item.name,
|
||||
})
|
||||
const {
|
||||
deleteCredentialId,
|
||||
doingAction,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,28 @@
|
|||
import { useCallback } from 'react'
|
||||
import { useInvalidDataSourceListAuth } from '@/service/use-datasource'
|
||||
import { useInvalidDataSourceAuth, useInvalidDataSourceListAuth } from '@/service/use-datasource'
|
||||
import { useInvalidDefaultDataSourceListAuth } from '@/service/use-datasource'
|
||||
import { useInvalidDataSourceList } from '@/service/use-pipeline'
|
||||
|
||||
export const useDataSourceAuthUpdate = () => {
|
||||
export const useDataSourceAuthUpdate = ({
|
||||
pluginId,
|
||||
provider,
|
||||
}: {
|
||||
pluginId: string
|
||||
provider: string
|
||||
}) => {
|
||||
const invalidateDataSourceListAuth = useInvalidDataSourceListAuth()
|
||||
const invalidDefaultDataSourceListAuth = useInvalidDefaultDataSourceListAuth()
|
||||
const invalidateDataSourceList = useInvalidDataSourceList()
|
||||
const invalidateDataSourceAuth = useInvalidDataSourceAuth({
|
||||
pluginId,
|
||||
provider,
|
||||
})
|
||||
const handleAuthUpdate = useCallback(() => {
|
||||
invalidateDataSourceListAuth()
|
||||
invalidDefaultDataSourceListAuth()
|
||||
invalidateDataSourceList()
|
||||
}, [invalidateDataSourceListAuth, invalidateDataSourceList])
|
||||
invalidateDataSourceAuth()
|
||||
}, [invalidateDataSourceListAuth, invalidateDataSourceList, invalidateDataSourceAuth, invalidDefaultDataSourceListAuth])
|
||||
|
||||
return {
|
||||
handleAuthUpdate,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import { useNodes } from 'reactflow'
|
|||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useDataSourceStore } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store'
|
||||
import { CrawlStep } from '@/models/datasets'
|
||||
|
||||
export const useTestRunSteps = () => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -56,3 +58,72 @@ export const useDatasourceOptions = () => {
|
|||
|
||||
return options
|
||||
}
|
||||
|
||||
export const useOnlineDocument = () => {
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const clearOnlineDocumentData = useCallback(() => {
|
||||
const {
|
||||
setDocumentsData,
|
||||
setSearchValue,
|
||||
setSelectedPagesId,
|
||||
setOnlineDocuments,
|
||||
setCurrentDocument,
|
||||
} = dataSourceStore.getState()
|
||||
setDocumentsData([])
|
||||
setSearchValue('')
|
||||
setSelectedPagesId(new Set())
|
||||
setOnlineDocuments([])
|
||||
setCurrentDocument(undefined)
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
clearOnlineDocumentData,
|
||||
}
|
||||
}
|
||||
|
||||
export const useWebsiteCrawl = () => {
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const clearWebsiteCrawlData = useCallback(() => {
|
||||
const {
|
||||
setStep,
|
||||
setCrawlResult,
|
||||
setWebsitePages,
|
||||
setPreviewIndex,
|
||||
setCurrentWebsite,
|
||||
} = dataSourceStore.getState()
|
||||
setStep(CrawlStep.init)
|
||||
setCrawlResult(undefined)
|
||||
setCurrentWebsite(undefined)
|
||||
setWebsitePages([])
|
||||
setPreviewIndex(-1)
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
clearWebsiteCrawlData,
|
||||
}
|
||||
}
|
||||
|
||||
export const useOnlineDrive = () => {
|
||||
const dataSourceStore = useDataSourceStore()
|
||||
|
||||
const clearOnlineDriveData = useCallback(() => {
|
||||
const {
|
||||
setFileList,
|
||||
setBucket,
|
||||
setPrefix,
|
||||
setKeywords,
|
||||
setSelectedFileKeys,
|
||||
} = dataSourceStore.getState()
|
||||
setFileList([])
|
||||
setBucket('')
|
||||
setPrefix([])
|
||||
setKeywords('')
|
||||
setSelectedFileKeys([])
|
||||
}, [dataSourceStore])
|
||||
|
||||
return {
|
||||
clearOnlineDriveData,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
import { useStore as useWorkflowStoreWithSelector } from '@/app/components/workflow/store'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTestRunSteps } from './hooks'
|
||||
import {
|
||||
useOnlineDocument,
|
||||
useOnlineDrive,
|
||||
useTestRunSteps,
|
||||
useWebsiteCrawl,
|
||||
} from './hooks'
|
||||
import DataSourceOptions from './data-source-options'
|
||||
import LocalFile from '@/app/components/datasets/documents/create-from-pipeline/data-source/local-file'
|
||||
import OnlineDocuments from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents'
|
||||
|
|
@ -42,6 +47,10 @@ const TestRunPanel = () => {
|
|||
handleBackStep,
|
||||
} = useTestRunSteps()
|
||||
|
||||
const { clearOnlineDocumentData } = useOnlineDocument()
|
||||
const { clearWebsiteCrawlData } = useWebsiteCrawl()
|
||||
const { clearOnlineDriveData } = useOnlineDrive()
|
||||
|
||||
const datasourceType = datasource?.nodeData.provider_type
|
||||
|
||||
const nextBtnDisabled = useMemo(() => {
|
||||
|
|
@ -67,6 +76,7 @@ const TestRunPanel = () => {
|
|||
if (!datasource)
|
||||
return
|
||||
const datasourceInfoList: Record<string, any>[] = []
|
||||
const credentialId = dataSourceStore.getState().currentCredentialId
|
||||
if (datasourceType === DatasourceType.localFile) {
|
||||
const { id, name, type, size, extension, mime_type } = fileList[0].file
|
||||
const documentInfo = {
|
||||
|
|
@ -86,16 +96,22 @@ const TestRunPanel = () => {
|
|||
const documentInfo = {
|
||||
workspace_id,
|
||||
page: rest,
|
||||
credential_id: credentialId,
|
||||
}
|
||||
datasourceInfoList.push(documentInfo)
|
||||
}
|
||||
if (datasourceType === DatasourceType.websiteCrawl)
|
||||
datasourceInfoList.push(websitePages[0])
|
||||
if (datasourceType === DatasourceType.websiteCrawl) {
|
||||
datasourceInfoList.push({
|
||||
...websitePages[0],
|
||||
credential_id: credentialId,
|
||||
})
|
||||
}
|
||||
if (datasourceType === DatasourceType.onlineDrive) {
|
||||
const { bucket } = dataSourceStore.getState()
|
||||
datasourceInfoList.push({
|
||||
bucket,
|
||||
key: selectedFileKeys[0],
|
||||
credential_id: credentialId,
|
||||
})
|
||||
}
|
||||
handleRun({
|
||||
|
|
@ -106,6 +122,32 @@ const TestRunPanel = () => {
|
|||
})
|
||||
}, [dataSourceStore, datasource, datasourceType, fileList, handleRun, onlineDocuments, selectedFileKeys, websitePages])
|
||||
|
||||
const clearDataSourceData = useCallback((dataSource: Datasource) => {
|
||||
if (dataSource.nodeData.provider_type === DatasourceType.onlineDocument)
|
||||
clearOnlineDocumentData()
|
||||
else if (dataSource.nodeData.provider_type === DatasourceType.websiteCrawl)
|
||||
clearWebsiteCrawlData()
|
||||
else if (dataSource.nodeData.provider_type === DatasourceType.onlineDrive)
|
||||
clearOnlineDriveData()
|
||||
}, [])
|
||||
|
||||
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
|
||||
const {
|
||||
setCurrentCredentialId,
|
||||
currentNodeIdRef,
|
||||
} = dataSourceStore.getState()
|
||||
clearDataSourceData(dataSource)
|
||||
setCurrentCredentialId('')
|
||||
currentNodeIdRef.current = dataSource.nodeId
|
||||
setDatasource(dataSource)
|
||||
}, [dataSourceStore])
|
||||
|
||||
const handleCredentialChange = useCallback((credentialId: string) => {
|
||||
const { setCurrentCredentialId } = dataSourceStore.getState()
|
||||
clearDataSourceData(datasource!)
|
||||
setCurrentCredentialId(credentialId)
|
||||
}, [dataSourceStore, datasource])
|
||||
|
||||
return (
|
||||
<div
|
||||
className='relative flex h-full w-[480px] flex-col rounded-l-2xl border-y-[0.5px] border-l-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl shadow-shadow-shadow-1'
|
||||
|
|
@ -119,7 +161,7 @@ const TestRunPanel = () => {
|
|||
<div className='flex flex-col gap-y-4 px-4 py-2'>
|
||||
<DataSourceOptions
|
||||
dataSourceNodeId={datasource?.nodeId || ''}
|
||||
onSelect={setDatasource}
|
||||
onSelect={handleSwitchDataSource}
|
||||
/>
|
||||
{datasourceType === DatasourceType.localFile && (
|
||||
<LocalFile
|
||||
|
|
@ -132,6 +174,7 @@ const TestRunPanel = () => {
|
|||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
isInPipeline
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
{datasourceType === DatasourceType.websiteCrawl && (
|
||||
|
|
@ -139,6 +182,7 @@ const TestRunPanel = () => {
|
|||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
isInPipeline
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
{datasourceType === DatasourceType.onlineDrive && (
|
||||
|
|
@ -146,6 +190,7 @@ const TestRunPanel = () => {
|
|||
nodeId={datasource!.nodeId}
|
||||
nodeData={datasource!.nodeData}
|
||||
isInPipeline
|
||||
onCredentialChange={handleCredentialChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -218,7 +218,6 @@ export const useShortcuts = (): void => {
|
|||
useKeyPress(
|
||||
'shift',
|
||||
(e) => {
|
||||
console.log('Shift down', e)
|
||||
if (shouldHandleShortcut(e))
|
||||
dimOtherNodes()
|
||||
},
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import type { ValueSelector } from '@/app/components/workflow/types'
|
|||
|
||||
type Props = {
|
||||
className?: string
|
||||
root: { nodeId?: string, nodeName?: string, attrName: string }
|
||||
root: { nodeId?: string, nodeName?: string, attrName: string, attrAlias?: string }
|
||||
payload: StructuredOutput
|
||||
readonly?: boolean
|
||||
onSelect?: (valueSelector: ValueSelector) => void
|
||||
|
|
@ -52,8 +52,7 @@ export const PickerPanelMain: FC<Props> = ({
|
|||
)}
|
||||
<div className='system-sm-medium text-text-secondary'>{root.attrName}</div>
|
||||
</div>
|
||||
{/* It must be object */}
|
||||
<div className='system-xs-regular ml-2 shrink-0 text-text-tertiary'>object</div>
|
||||
<div className='system-xs-regular ml-2 truncate text-text-tertiary' title={root.attrAlias || 'object'}>{root.attrAlias || 'object'}</div>
|
||||
</div>
|
||||
{fieldNames.map(name => (
|
||||
<Field
|
||||
|
|
|
|||
|
|
@ -217,6 +217,7 @@ const findExceptVarInObject = (obj: any, filterVar: (payload: Var, selector: Val
|
|||
variable: obj.variable,
|
||||
type: isFile ? VarType.file : VarType.object,
|
||||
children: childrenResult,
|
||||
alias: obj.alias,
|
||||
}
|
||||
|
||||
return res
|
||||
|
|
@ -412,6 +413,7 @@ const formatItem = (
|
|||
? `array[${output.items?.type.slice(0, 1).toLocaleLowerCase()}${output.items?.type.slice(1)}]`
|
||||
: `${output.type.slice(0, 1).toLocaleLowerCase()}${output.type.slice(1)}`,
|
||||
description: output.description,
|
||||
alias: output?.properties?.dify_builtin_type?.enum?.[0],
|
||||
children: output.type === 'object' ? {
|
||||
schema: {
|
||||
type: 'object',
|
||||
|
|
@ -518,41 +520,8 @@ const formatItem = (
|
|||
|
||||
case BlockEnum.DataSource: {
|
||||
const payload = data as DataSourceNodeType
|
||||
const baseVars = DataSourceNodeDefault.getOutputVars?.(payload, ragVars) || []
|
||||
if (payload.output_schema?.properties) {
|
||||
const dynamicOutputSchema: any[] = []
|
||||
Object.keys(payload.output_schema.properties).forEach((outputKey) => {
|
||||
const output = payload.output_schema!.properties[outputKey]
|
||||
const dataType = output?.properties?.dify_builtin_type ? output.properties.dify_builtin_type.enum[0] : output.type
|
||||
dynamicOutputSchema.push({
|
||||
variable: outputKey,
|
||||
type: dataType === 'array'
|
||||
? `array[${output.items?.type.slice(0, 1).toLocaleLowerCase()}${output.items?.type.slice(1)}]`
|
||||
: `${dataType.slice(0, 1).toLocaleLowerCase()}${dataType.slice(1)}`,
|
||||
description: output.description,
|
||||
children: output.type === 'object' ? {
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: Object.fromEntries(
|
||||
Object.entries(output.properties).filter(([key]) => key !== 'dify_builtin_type'),
|
||||
),
|
||||
},
|
||||
} : undefined,
|
||||
})
|
||||
})
|
||||
res.vars = [
|
||||
...baseVars,
|
||||
...dynamicOutputSchema,
|
||||
{
|
||||
variable: 'output',
|
||||
type: VarType.object,
|
||||
children: dynamicOutputSchema,
|
||||
},
|
||||
]
|
||||
}
|
||||
else {
|
||||
res.vars = baseVars
|
||||
}
|
||||
const dataSourceVars = DataSourceNodeDefault.getOutputVars?.(payload, ragVars) || []
|
||||
res.vars = dataSourceVars
|
||||
break
|
||||
}
|
||||
|
||||
|
|
@ -952,7 +921,7 @@ export const getVarType = ({
|
|||
const isStructuredOutputVar = !!targetVar.children?.schema?.properties
|
||||
if (isStructuredOutputVar) {
|
||||
if (valueSelector.length === 2) { // root
|
||||
return VarType.object
|
||||
return targetVar.alias || VarType.object
|
||||
}
|
||||
let currProperties = targetVar.children.schema;
|
||||
(valueSelector as ValueSelector).slice(2).forEach((key, i) => {
|
||||
|
|
|
|||
|
|
@ -173,7 +173,7 @@ const Item: FC<ItemProps> = ({
|
|||
<div title={itemData.des} className='system-sm-medium ml-1 w-0 grow truncate text-text-secondary'>{itemData.variable.split('.').slice(-1)[0]}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className='ml-1 shrink-0 text-xs font-normal capitalize text-text-tertiary'>{itemData.type}</div>
|
||||
<div className='ml-1 shrink-0 text-xs font-normal capitalize text-text-tertiary'>{itemData.alias || itemData.type}</div>
|
||||
{
|
||||
(isObj || isStructureOutput) && (
|
||||
<ChevronRight className={cn('ml-0.5 h-3 w-3 text-text-quaternary', isHovering && 'text-text-tertiary')} />
|
||||
|
|
@ -186,7 +186,7 @@ const Item: FC<ItemProps> = ({
|
|||
}}>
|
||||
{(isStructureOutput || isObj) && (
|
||||
<PickerStructurePanel
|
||||
root={{ nodeId, nodeName: title, attrName: itemData.variable }}
|
||||
root={{ nodeId, nodeName: title, attrName: itemData.variable, attrAlias: itemData.alias }}
|
||||
payload={structuredOutput!}
|
||||
onHovering={setIsChildrenHovering}
|
||||
onSelect={(valueSelector) => {
|
||||
|
|
|
|||
|
|
@ -58,6 +58,29 @@ const nodeDefault: NodeDefault<DataSourceNodeType> = {
|
|||
provider_type,
|
||||
} = payload
|
||||
const isLocalFile = provider_type === DataSourceClassification.localFile
|
||||
const dynamicOutputSchema: any[] = []
|
||||
if (payload.output_schema?.properties) {
|
||||
Object.keys(payload.output_schema.properties).forEach((outputKey) => {
|
||||
const output = payload.output_schema!.properties[outputKey]
|
||||
const dataType = output.type
|
||||
dynamicOutputSchema.push({
|
||||
variable: outputKey,
|
||||
type: dataType === 'array'
|
||||
? `array[${output.items?.type.slice(0, 1).toLocaleLowerCase()}${output.items?.type.slice(1)}]`
|
||||
: `${dataType.slice(0, 1).toLocaleLowerCase()}${dataType.slice(1)}`,
|
||||
description: output.description,
|
||||
alias: output?.properties?.dify_builtin_type?.enum?.[0],
|
||||
children: output.type === 'object' ? {
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: Object.fromEntries(
|
||||
Object.entries(output.properties).filter(([key]) => key !== 'dify_builtin_type'),
|
||||
),
|
||||
},
|
||||
} : undefined,
|
||||
})
|
||||
})
|
||||
}
|
||||
return [
|
||||
...COMMON_OUTPUT.map(item => ({ variable: item.name, type: item.type })),
|
||||
...(
|
||||
|
|
@ -66,6 +89,7 @@ const nodeDefault: NodeDefault<DataSourceNodeType> = {
|
|||
: []
|
||||
),
|
||||
...ragVars,
|
||||
...dynamicOutputSchema,
|
||||
]
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,14 +15,13 @@ import StructureOutputItem from '@/app/components/workflow/nodes/_base/component
|
|||
import TagInput from '@/app/components/base/tag-input'
|
||||
import { useNodesReadOnly } from '@/app/components/workflow/hooks'
|
||||
import { useConfig } from './hooks/use-config'
|
||||
import type { StructuredOutput } from '@/app/components/workflow/nodes/llm/types'
|
||||
import { Type } from '@/app/components/workflow/nodes/llm/types'
|
||||
import {
|
||||
COMMON_OUTPUT,
|
||||
} from './constants'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
||||
import ToolForm from '../tool/components/tool-form'
|
||||
import { wrapStructuredVarItem } from '@/app/components/workflow/utils/tool'
|
||||
|
||||
const Panel: FC<NodePanelProps<DataSourceNodeType>> = ({ id, data }) => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -49,24 +48,7 @@ const Panel: FC<NodePanelProps<DataSourceNodeType>> = ({ id, data }) => {
|
|||
|
||||
const pipelineId = useStore(s => s.pipelineId)
|
||||
const setShowInputFieldPanel = useStore(s => s.setShowInputFieldPanel)
|
||||
const wrapStructuredVarItem = (outputItem: any): StructuredOutput => {
|
||||
const dataType = outputItem.value?.properties?.dify_builtin_type ? outputItem.value?.properties?.dify_builtin_type.enum[0] : Type.object
|
||||
const properties = Object.fromEntries(
|
||||
Object.entries(outputItem.value?.properties || {}).filter(([key]) => key !== 'dify_builtin_type'),
|
||||
) as Record<string, any>
|
||||
return {
|
||||
schema: {
|
||||
type: dataType,
|
||||
properties: {
|
||||
[outputItem.name]: {
|
||||
...outputItem.value,
|
||||
properties,
|
||||
},
|
||||
},
|
||||
additionalProperties: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div >
|
||||
{
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ export type Field = {
|
|||
items?: ArrayItems // Array has items. Define the item type
|
||||
enum?: SchemaEnumType // Enum values
|
||||
additionalProperties?: false // Required in object by api. Just set false
|
||||
alias?: string // Alias of the field
|
||||
}
|
||||
|
||||
export type StructuredOutput = {
|
||||
|
|
|
|||
|
|
@ -10,9 +10,12 @@ export const checkNodeValid = (_payload: LLMNodeType) => {
|
|||
}
|
||||
|
||||
export const getFieldType = (field: Field) => {
|
||||
const { type, items } = field
|
||||
if (type !== Type.array || !items)
|
||||
const { type, items, alias } = field
|
||||
if (type !== Type.array || !items) {
|
||||
if (alias)
|
||||
return alias
|
||||
return type
|
||||
}
|
||||
|
||||
return ArrayType[items.type]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ import type { NodePanelProps } from '@/app/components/workflow/types'
|
|||
import Loading from '@/app/components/base/loading'
|
||||
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
|
||||
import StructureOutputItem from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/show'
|
||||
import { Type } from '../llm/types'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { wrapStructuredVarItem } from '@/app/components/workflow/utils/tool'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.tool'
|
||||
|
||||
|
|
@ -121,15 +121,7 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
|
|||
{outputItem.value?.type === 'object' ? (
|
||||
<StructureOutputItem
|
||||
rootClassName='code-sm-semibold text-text-secondary'
|
||||
payload={{
|
||||
schema: {
|
||||
type: Type.object,
|
||||
properties: {
|
||||
[outputItem.name]: outputItem.value,
|
||||
},
|
||||
additionalProperties: false,
|
||||
},
|
||||
}} />
|
||||
payload={wrapStructuredVarItem(outputItem)} />
|
||||
) : (
|
||||
<VarItem
|
||||
name={outputItem.name}
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ const DebugAndPreview = () => {
|
|||
<div
|
||||
ref={containerRef}
|
||||
className={cn(
|
||||
'relative flex h-full flex-col rounded-l-2xl border border-r-0 border-components-panel-border bg-chatbot-bg shadow-xl',
|
||||
'relative flex h-full flex-col rounded-l-2xl border border-r-0 border-components-panel-border bg-components-panel-bg shadow-xl',
|
||||
)}
|
||||
style={{ width: `${panelWidth}px` }}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ const ShortcutsName = ({
|
|||
keys.map(key => (
|
||||
<div
|
||||
key={key}
|
||||
className='system-kbd flex h-4 w-4 items-center justify-center rounded-[4px] bg-components-kbd-bg-gray capitalize'
|
||||
className='system-kbd flex h-4 min-w-4 items-center justify-center rounded-[4px] bg-components-kbd-bg-gray capitalize'
|
||||
>
|
||||
{getKeyboardKeyNameBySystem(key)}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -302,6 +302,7 @@ export type Var = {
|
|||
isLoopVariable?: boolean
|
||||
nodeId?: string
|
||||
isRagVariable?: boolean
|
||||
alias?: string
|
||||
}
|
||||
|
||||
export type NodeOutPutVar = {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import type { ToolNodeType } from '../nodes/tool/types'
|
|||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
|
||||
import { canFindTool } from '@/utils'
|
||||
import type { StructuredOutput } from '@/app/components/workflow/nodes/llm/types'
|
||||
import { Type } from '@/app/components/workflow/nodes/llm/types'
|
||||
|
||||
export const getToolCheckParams = (
|
||||
toolData: ToolNodeType,
|
||||
|
|
@ -41,3 +43,23 @@ export const getToolCheckParams = (
|
|||
language,
|
||||
}
|
||||
}
|
||||
|
||||
export const wrapStructuredVarItem = (outputItem: any): StructuredOutput => {
|
||||
const dataType = Type.object
|
||||
const properties = Object.fromEntries(
|
||||
Object.entries(outputItem.value?.properties || {}).filter(([key]) => key !== 'dify_builtin_type'),
|
||||
) as Record<string, any>
|
||||
return {
|
||||
schema: {
|
||||
type: dataType,
|
||||
properties: {
|
||||
[outputItem.name]: {
|
||||
...outputItem.value,
|
||||
properties,
|
||||
alias: outputItem.value?.properties?.dify_builtin_type?.enum?.[0],
|
||||
},
|
||||
},
|
||||
additionalProperties: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
@use '../../themes/light';
|
||||
@use '../../themes/dark';
|
||||
@use '../../themes/markdown-light';
|
||||
@use '../../themes/markdown-dark';
|
||||
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ const translation = {
|
|||
after: '',
|
||||
},
|
||||
},
|
||||
contentEnableLabel: 'Moderater Inhalt aktiviert',
|
||||
contentEnableLabel: 'Inhaltsmoderation aktiviert',
|
||||
},
|
||||
fileUpload: {
|
||||
title: 'Datei-Upload',
|
||||
|
|
|
|||
|
|
@ -166,6 +166,10 @@ const translation = {
|
|||
description: 'Gibt an, ob das web app Symbol zum Ersetzen 🤖 in der freigegebenen Anwendung verwendet werden soll',
|
||||
},
|
||||
importFromDSLUrlPlaceholder: 'DSL-Link hier einfügen',
|
||||
dslUploader: {
|
||||
button: 'Datei per Drag & Drop ablegen oder',
|
||||
browse: 'Durchsuchen',
|
||||
},
|
||||
duplicate: 'Duplikat',
|
||||
importFromDSL: 'Import von DSL',
|
||||
importDSL: 'DSL-Datei importieren',
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ const translation = {
|
|||
uploader: {
|
||||
title: 'Textdatei hochladen',
|
||||
button: 'Dateien und Ordner hierher ziehen oder klicken',
|
||||
buttonSingleFile: 'Datei hierher ziehen oder klicken',
|
||||
browse: 'Durchsuchen',
|
||||
tip: 'Unterstützt {{supportTypes}}. Maximal {{size}}MB pro Datei.',
|
||||
validation: {
|
||||
|
|
|
|||
|
|
@ -287,6 +287,18 @@ const translation = {
|
|||
zoomTo50: 'Auf 50% vergrößern',
|
||||
zoomTo100: 'Auf 100% vergrößern',
|
||||
zoomToFit: 'An Bildschirm anpassen',
|
||||
selectionAlignment: 'Ausrichtung der Auswahl',
|
||||
alignLeft: 'Links',
|
||||
alignTop: 'Nach oben',
|
||||
distributeVertical: 'Vertikaler Raum',
|
||||
alignBottom: 'Unteres',
|
||||
distributeHorizontal: 'Horizontaler Raum',
|
||||
vertical: 'Senkrecht',
|
||||
alignMiddle: 'Mitte',
|
||||
alignCenter: 'Mitte',
|
||||
alignRight: 'Rechts',
|
||||
alignNodes: 'Knoten ausrichten',
|
||||
horizontal: 'Horizontal',
|
||||
},
|
||||
panel: {
|
||||
userInputField: 'Benutzereingabefeld',
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue