mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/mcp-authentication
This commit is contained in:
commit
9c6d059227
|
|
@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
|
|||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
|
|
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -45,6 +45,79 @@ def _validate_name(name: str) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get supported retrieval methods based on vector database type.
|
||||
|
||||
Args:
|
||||
vector_type: Vector database type, can be None
|
||||
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||
|
||||
Returns:
|
||||
Dictionary containing supported retrieval methods
|
||||
|
||||
Raises:
|
||||
ValueError: If vector_type is None or unsupported
|
||||
"""
|
||||
if vector_type is None:
|
||||
raise ValueError("Vector store type is not configured.")
|
||||
|
||||
# Define vector database types that only support semantic search
|
||||
semantic_only_types = {
|
||||
VectorType.RELYT,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.CHROMA,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
|
||||
# Define vector database types that support all retrieval methods
|
||||
full_search_types = {
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.ELASTICSEARCH_JA,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.TIDB_ON_QDRANT,
|
||||
VectorType.LINDORM,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
full_methods = {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
|
||||
if vector_type == VectorType.MILVUS:
|
||||
return semantic_methods if is_mock else full_methods
|
||||
|
||||
if vector_type in semantic_only_types:
|
||||
return semantic_methods
|
||||
elif vector_type in full_search_types:
|
||||
return full_methods
|
||||
else:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
class DatasetListApi(Resource):
|
||||
@api.doc("get_datasets")
|
||||
|
|
@ -777,50 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
| VectorType.ALIBABACLOUD_MYSQL
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
|
|
@ -833,49 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
| VectorType.BAIDU
|
||||
| VectorType.ALIBABACLOUD_MYSQL
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
|
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import InstalledApp
|
||||
from models.account import Account
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
|
|
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
assert isinstance(current_user, Account)
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
|
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
|
|
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
|
|
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,6 +23,8 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_type = request.args.get("type", type=str, default="")
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
|
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
|
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
|
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
return current_user, tenant_id
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
|
|
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
|
|||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
|
|
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
|
|
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
|
|||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
|
|
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -26,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
|
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import AccountStatus
|
||||
from libs.login import current_user
|
||||
from models.account import Account, AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
|
@ -25,11 +25,16 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
assert isinstance(current_user, Account)
|
||||
return current_user
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
account = _current_account()
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
raise AccountNotInitializedError()
|
||||
|
|
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -70,7 +70,11 @@ class ModelConfigConverter:
|
|||
if not model_mode:
|
||||
model_mode = LLMMode.CHAT
|
||||
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
|
||||
try:
|
||||
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
|
||||
except ValueError:
|
||||
# Fall back to CHAT mode if the stored value is invalid
|
||||
model_mode = LLMMode.CHAT
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
|
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
|
|||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
class DatasourceInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
|||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
|
||||
|
||||
class ConfigurateMethod(Enum):
|
||||
class ConfigurateMethod(StrEnum):
|
||||
"""
|
||||
Enum class for configurate method of provider model.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ class BasePluginClient:
|
|||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
|
||||
logger.error("Error in stream response for plugin %s", rep.__dict__)
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
|
|
|
|||
|
|
@ -1046,7 +1046,7 @@ class ProviderManager:
|
|||
"""
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
if credential_form_schema.type.value == FormType.SECRET_INPUT:
|
||||
if credential_form_schema.type == FormType.SECRET_INPUT:
|
||||
secret_input_form_variables.append(credential_form_schema.variable)
|
||||
|
||||
return secret_input_form_variables
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class RetrievalService:
|
|||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
|
|
@ -56,7 +56,7 @@ class RetrievalService:
|
|||
# Optimize multithreading with thread pools
|
||||
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
|
||||
futures = []
|
||||
if retrieval_method == "keyword_search":
|
||||
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
|
||||
futures.append(
|
||||
executor.submit(
|
||||
cls.keyword_search,
|
||||
|
|
@ -220,7 +220,7 @@ class RetrievalService:
|
|||
score_threshold: float | None,
|
||||
reranking_model: dict | None,
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
exceptions: list,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatasourceStreamEvent(Enum):
|
||||
class DatasourceStreamEvent(StrEnum):
|
||||
"""
|
||||
Datasource Stream event
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -49,7 +50,8 @@ class UnstructuredWordExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -46,7 +47,8 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
|
||||
import pypandoc # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -40,7 +41,8 @@ class UnstructuredEpubExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -32,7 +33,8 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
|
|||
elements = partition_md(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -31,7 +32,8 @@ class UnstructuredMsgExtractor(BaseExtractor):
|
|||
elements = partition_msg(filename=self._file_path)
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
|
@ -32,7 +33,8 @@ class UnstructuredXmlExtractor(BaseExtractor):
|
|||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
|
||||
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
|
||||
documents = []
|
||||
for chunk in chunks:
|
||||
text = chunk.text.strip()
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
from configs import dify_config
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
|
|
@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
|
|||
@abstractmethod
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
|
@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
|
|
@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
|||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import Document, QAStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
|
|
@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
|||
|
||||
def retrieve(
|
||||
self,
|
||||
retrieval_method: str,
|
||||
retrieval_method: RetrievalMethod,
|
||||
query: str,
|
||||
dataset: Dataset,
|
||||
top_k: int,
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ class DatasetRetrieval:
|
|||
top_k = retrieval_model_config["top_k"]
|
||||
# get retrieval method
|
||||
if dataset.indexing_technique == "economy":
|
||||
retrieval_method = "keyword_search"
|
||||
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
else:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
|
|
@ -623,7 +623,7 @@ class DatasetRetrieval:
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RetrievalMethod(Enum):
|
||||
class RetrievalMethod(StrEnum):
|
||||
SEMANTIC_SEARCH = "semantic_search"
|
||||
FULL_TEXT_SEARCH = "full_text_search"
|
||||
HYBRID_SEARCH = "hybrid_search"
|
||||
|
|
|
|||
|
|
@ -76,7 +76,8 @@ class MCPToolProviderController(ToolProviderController):
|
|||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
|
||||
if not db_provider.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search",
|
||||
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
|
|||
"""
|
||||
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
|
|
@ -118,10 +119,10 @@ class ModelInvocationUtils:
|
|||
model_response="",
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
answer_unit_price=Decimal(),
|
||||
answer_price_unit=Decimal(),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
total_price=Decimal(),
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
|
|
@ -152,7 +153,7 @@ class ModelInvocationUtils:
|
|||
raise InvokeModelError(f"Invoke error: {e}")
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
tool_model_invoke.model_response = str(response.message.content)
|
||||
if response.usage:
|
||||
tool_model_invoke.answer_tokens = response.usage.completion_tokens
|
||||
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum, StrEnum
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeState(Enum):
|
||||
class NodeState(StrEnum):
|
||||
"""State of a node or edge during workflow execution."""
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
|
|||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
|
@ -24,7 +24,7 @@ from core.workflow.graph_events import (
|
|||
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(Enum):
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Literal, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -37,10 +37,11 @@ config.set_main_option('sqlalchemy.url', get_engine_url())
|
|||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
|
||||
def get_metadata():
|
||||
return Base.metadata
|
||||
return TypeBase.metadata
|
||||
|
||||
def include_object(object, name, type_, reflected, compare_to):
|
||||
if type_ == "foreign_key_constraint":
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ from sqlalchemy import DateTime, String, func
|
|||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class DataSourceOauthBinding(Base):
|
||||
class DataSourceOauthBinding(TypeBase):
|
||||
__tablename__ = "data_source_oauth_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||
|
|
@ -19,17 +19,25 @@ class DataSourceOauthBinding(Base):
|
|||
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
source_info = mapped_column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
source_info: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(Base):
|
||||
class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||
|
|
@ -37,14 +45,22 @@ class DataSourceApiKeyAuthBinding(Base):
|
|||
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
credentials = mapped_column(sa.Text, nullable=True) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
|
||||
credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) # JSON
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
@ -52,7 +68,7 @@ class DataSourceApiKeyAuthBinding(Base):
|
|||
"tenant_id": self.tenant_id,
|
||||
"category": self.category,
|
||||
"provider": self.provider,
|
||||
"credentials": json.loads(self.credentials),
|
||||
"credentials": json.loads(self.credentials) if self.credentials else None,
|
||||
"created_at": self.created_at.timestamp(),
|
||||
"updated_at": self.updated_at.timestamp(),
|
||||
"disabled": self.disabled,
|
||||
|
|
|
|||
|
|
@ -6,41 +6,43 @@ from sqlalchemy import DateTime, String
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
|
||||
class CeleryTask(Base):
|
||||
class CeleryTask(TypeBase):
|
||||
"""Task result/status."""
|
||||
|
||||
__tablename__ = "celery_taskmeta"
|
||||
|
||||
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
|
||||
task_id = mapped_column(String(155), unique=True)
|
||||
status = mapped_column(String(50), default=states.PENDING)
|
||||
result = mapped_column(sa.PickleType, nullable=True)
|
||||
date_done = mapped_column(
|
||||
id: Mapped[int] = mapped_column(
|
||||
sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True, init=False
|
||||
)
|
||||
task_id: Mapped[str] = mapped_column(String(155), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(50), default=states.PENDING)
|
||||
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
|
||||
date_done: Mapped[datetime | None] = mapped_column(
|
||||
DateTime,
|
||||
default=lambda: naive_utc_now(),
|
||||
onupdate=lambda: naive_utc_now(),
|
||||
default=naive_utc_now,
|
||||
onupdate=naive_utc_now,
|
||||
nullable=True,
|
||||
)
|
||||
traceback = mapped_column(sa.Text, nullable=True)
|
||||
name = mapped_column(String(155), nullable=True)
|
||||
args = mapped_column(sa.LargeBinary, nullable=True)
|
||||
kwargs = mapped_column(sa.LargeBinary, nullable=True)
|
||||
worker = mapped_column(String(155), nullable=True)
|
||||
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
|
||||
queue = mapped_column(String(155), nullable=True)
|
||||
traceback: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
name: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
args: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
|
||||
kwargs: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
|
||||
worker: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
|
||||
queue: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
|
||||
|
||||
|
||||
class CeleryTaskSet(Base):
|
||||
class CeleryTaskSet(TypeBase):
|
||||
"""TaskSet result."""
|
||||
|
||||
__tablename__ = "celery_tasksetmeta"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
|
||||
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True, init=False
|
||||
)
|
||||
taskset_id = mapped_column(String(155), unique=True)
|
||||
result = mapped_column(sa.PickleType, nullable=True)
|
||||
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)
|
||||
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
|
||||
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
|
||||
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -13,7 +14,7 @@ from core.helper import encrypter
|
|||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
from models.base import Base, TypeBase
|
||||
from models.base import TypeBase
|
||||
|
||||
from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
|
|
@ -42,28 +43,28 @@ class ToolOAuthSystemClient(TypeBase):
|
|||
|
||||
|
||||
# tenant level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthTenantClient(Base):
|
||||
class ToolOAuthTenantClient(TypeBase):
|
||||
__tablename__ = "tool_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), init=False)
|
||||
# oauth params of the tool provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False, init=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class BuiltinToolProvider(Base):
|
||||
class BuiltinToolProvider(TypeBase):
|
||||
"""
|
||||
This table stores the tool provider information for built-in tools for each tenant.
|
||||
"""
|
||||
|
|
@ -75,37 +76,45 @@ class BuiltinToolProvider(Base):
|
|||
)
|
||||
|
||||
# id of the tool provider
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
|
||||
String(256),
|
||||
nullable=False,
|
||||
server_default=sa.text("'API KEY 1'::character varying"),
|
||||
)
|
||||
# id of the tenant
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
||||
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
||||
# who created this tool provider
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# name of the tool provider
|
||||
provider: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
|
||||
# credential type, e.g., "api-key", "oauth2"
|
||||
credential_type: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
|
||||
String(32), nullable=False, server_default=sa.text("'api-key'::character varying"), default="api-key"
|
||||
)
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
|
||||
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1)
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
if not self.encrypted_credentials:
|
||||
return {}
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
|
||||
|
||||
|
||||
class ApiToolProvider(Base):
|
||||
class ApiToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the api providers.
|
||||
"""
|
||||
|
|
@ -116,31 +125,43 @@ class ApiToolProvider(Base):
|
|||
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the api provider
|
||||
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'API KEY 1'::character varying"),
|
||||
)
|
||||
# icon
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# original schema
|
||||
schema = mapped_column(sa.Text, nullable=False)
|
||||
schema: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# who created this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# description of the provider
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# json format tools
|
||||
tools_str = mapped_column(sa.Text, nullable=False)
|
||||
tools_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# json format credentials
|
||||
credentials_str = mapped_column(sa.Text, nullable=False)
|
||||
credentials_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# privacy policy
|
||||
privacy_policy = mapped_column(String(255), nullable=True)
|
||||
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
# custom_disclaimer
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def schema_type(self) -> "ApiProviderSchemaType":
|
||||
|
|
@ -189,7 +210,7 @@ class ToolLabelBinding(TypeBase):
|
|||
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
|
||||
|
||||
class WorkflowToolProvider(Base):
|
||||
class WorkflowToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the workflow providers.
|
||||
"""
|
||||
|
|
@ -201,7 +222,7 @@ class WorkflowToolProvider(Base):
|
|||
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the workflow provider
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# label of the workflow provider
|
||||
|
|
@ -219,15 +240,19 @@ class WorkflowToolProvider(Base):
|
|||
# description of the provider
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# parameter configuration
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
|
||||
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]", default="[]")
|
||||
# privacy policy
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
|
||||
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default=None)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -252,7 +277,7 @@ class WorkflowToolProvider(Base):
|
|||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
|
||||
|
||||
class MCPToolProvider(Base):
|
||||
class MCPToolProvider(TypeBase):
|
||||
"""
|
||||
The table stores the mcp providers.
|
||||
"""
|
||||
|
|
@ -265,7 +290,7 @@ class MCPToolProvider(Base):
|
|||
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# name of the mcp provider
|
||||
name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# server identifier of the mcp provider
|
||||
|
|
@ -275,27 +300,33 @@ class MCPToolProvider(Base):
|
|||
# hash of server_url for uniqueness check
|
||||
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# icon of the mcp provider
|
||||
icon: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
icon: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# encrypted credentials
|
||||
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
# authed
|
||||
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
# tools
|
||||
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"), default=30.0)
|
||||
sse_read_timeout: Mapped[float] = mapped_column(
|
||||
sa.Float, nullable=False, server_default=sa.text("300"), default=300.0
|
||||
)
|
||||
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
|
||||
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
|
||||
# encrypted headers for MCP server requests
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
|
@ -306,9 +337,11 @@ class MCPToolProvider(Base):
|
|||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
if not self.encrypted_credentials:
|
||||
return {}
|
||||
try:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
||||
except Exception:
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
@property
|
||||
|
|
@ -321,6 +354,7 @@ class MCPToolProvider(Base):
|
|||
def provider_icon(self) -> Mapping[str, str] | str:
|
||||
from core.file import helpers as file_helpers
|
||||
|
||||
assert self.icon
|
||||
try:
|
||||
return json.loads(self.icon)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -419,7 +453,7 @@ class MCPToolProvider(Base):
|
|||
return encrypter.decrypt(self.credentials)
|
||||
|
||||
|
||||
class ToolModelInvoke(Base):
|
||||
class ToolModelInvoke(TypeBase):
|
||||
"""
|
||||
store the invoke logs from tool invoke
|
||||
"""
|
||||
|
|
@ -427,37 +461,47 @@ class ToolModelInvoke(Base):
|
|||
__tablename__ = "tool_model_invokes"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# who invoke this tool
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# provider
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# type
|
||||
tool_type = mapped_column(String(40), nullable=False)
|
||||
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# tool name
|
||||
tool_name = mapped_column(String(128), nullable=False)
|
||||
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
# invoke parameters
|
||||
model_parameters = mapped_column(sa.Text, nullable=False)
|
||||
model_parameters: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# prompt messages
|
||||
prompt_messages = mapped_column(sa.Text, nullable=False)
|
||||
prompt_messages: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# invoke response
|
||||
model_response = mapped_column(sa.Text, nullable=False)
|
||||
model_response: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
|
||||
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price = mapped_column(sa.Numeric(10, 7))
|
||||
answer_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
|
||||
answer_price_unit: Mapped[Decimal] = mapped_column(
|
||||
sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
|
||||
)
|
||||
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
|
||||
currency: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
@deprecated
|
||||
class ToolConversationVariables(Base):
|
||||
class ToolConversationVariables(TypeBase):
|
||||
"""
|
||||
store the conversation variables from tool invoke
|
||||
"""
|
||||
|
|
@ -470,18 +514,26 @@ class ToolConversationVariables(Base):
|
|||
sa.Index("conversation_id_idx", "conversation_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# conversation user id
|
||||
user_id = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# tenant id
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# conversation id
|
||||
conversation_id = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# variables pool
|
||||
variables_str = mapped_column(sa.Text, nullable=False)
|
||||
variables_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
|
|
@ -519,7 +571,7 @@ class ToolFile(TypeBase):
|
|||
|
||||
|
||||
@deprecated
|
||||
class DeprecatedPublishedAppTool(Base):
|
||||
class DeprecatedPublishedAppTool(TypeBase):
|
||||
"""
|
||||
The table stores the apps published as a tool for each person.
|
||||
"""
|
||||
|
|
@ -530,26 +582,34 @@ class DeprecatedPublishedAppTool(Base):
|
|||
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
# id of the app
|
||||
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# who published this tool
|
||||
description = mapped_column(sa.Text, nullable=False)
|
||||
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# llm_description of the tool, for LLM
|
||||
llm_description = mapped_column(sa.Text, nullable=False)
|
||||
llm_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# query description, query will be seem as a parameter of the tool,
|
||||
# to describe this parameter to llm, we need this field
|
||||
query_description = mapped_column(sa.Text, nullable=False)
|
||||
query_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
# query name, the name of the query parameter
|
||||
query_name = mapped_column(String(40), nullable=False)
|
||||
query_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# name of the tool provider
|
||||
tool_name = mapped_column(String(40), nullable=False)
|
||||
tool_name: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
# author
|
||||
author = mapped_column(String(40), nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
|
||||
author: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> "I18nObject":
|
||||
|
|
|
|||
|
|
@ -4,46 +4,58 @@ import sqlalchemy as sa
|
|||
from sqlalchemy import DateTime, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from models.base import Base
|
||||
from models.base import TypeBase
|
||||
|
||||
from .engine import db
|
||||
from .model import Message
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class SavedMessage(Base):
|
||||
class SavedMessage(TypeBase):
|
||||
__tablename__ = "saved_messages"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
message_id = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role = mapped_column(
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
|
||||
|
||||
class PinnedConversation(Base):
|
||||
class PinnedConversation(TypeBase):
|
||||
__tablename__ = "pinned_conversations"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID)
|
||||
created_by_role = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
|
||||
created_by_role: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
nullable=False,
|
||||
server_default=sa.text("'end_user'::character varying"),
|
||||
)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ dependencies = [
|
|||
"markdown~=3.5.1",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.7.25",
|
||||
"opik~=1.8.72",
|
||||
"opentelemetry-api==1.27.0",
|
||||
"opentelemetry-distro==0.48b0",
|
||||
"opentelemetry-exporter-otlp==1.27.0",
|
||||
|
|
|
|||
|
|
@ -26,10 +26,9 @@ class ApiKeyAuthService:
|
|||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
|
@ -48,6 +47,8 @@ class ApiKeyAuthService:
|
|||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
|
|
|
|||
|
|
@ -1470,7 +1470,7 @@ class DocumentService:
|
|||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
if not dataset.retrieval_model:
|
||||
default_retrieval_model = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 4,
|
||||
|
|
@ -1752,7 +1752,7 @@ class DocumentService:
|
|||
# dataset.collection_binding_id = dataset_collection_binding.id
|
||||
# if not dataset.retrieval_model:
|
||||
# default_retrieval_model = {
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
# "reranking_enable": False,
|
||||
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
# "top_k": 2,
|
||||
|
|
@ -2205,7 +2205,7 @@ class DocumentService:
|
|||
retrieval_model = knowledge_config.retrieval_model
|
||||
else:
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=4,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class ParentMode(StrEnum):
|
||||
FULL_DOC = "full-doc"
|
||||
|
|
@ -95,7 +97,7 @@ class WeightModel(BaseModel):
|
|||
|
||||
|
||||
class RetrievalModel(BaseModel):
|
||||
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
|
||||
search_method: RetrievalMethod
|
||||
reranking_enable: bool
|
||||
reranking_model: RerankingModel | None = None
|
||||
reranking_mode: str | None = None
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from typing import Literal
|
|||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
|
|
@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
|
|||
Retrieval Setting.
|
||||
"""
|
||||
|
||||
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
|
||||
search_method: RetrievalMethod
|
||||
top_k: int
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
|
|||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class CustomConfigurationStatus(Enum):
|
||||
class CustomConfigurationStatus(StrEnum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ class ExternalDatasetService:
|
|||
else:
|
||||
raise ValueError(f"invalid endpoint: {endpoint}")
|
||||
try:
|
||||
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
|
||||
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
|
||||
except Exception as e:
|
||||
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
|
||||
if response.status_code == 502:
|
||||
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
|
||||
if response.status_code == 404:
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class HitTestingService:
|
|||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from flask_login import current_user
|
|||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
|
||||
|
|
@ -164,7 +165,7 @@ class RagPipelineTransformService:
|
|||
if retrieval_model:
|
||||
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
|
||||
if indexing_technique == "economy":
|
||||
retrieval_setting.search_method = "keyword_search"
|
||||
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
|
||||
knowledge_configuration.retrieval_model = retrieval_setting
|
||||
else:
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class ApiToolManageService:
|
|||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
credentials_str="{}",
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -683,7 +683,7 @@ class BuiltinToolManageService:
|
|||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = encrypter.decrypt(custom_client_params.oauth_params)
|
||||
new_params: dict = {
|
||||
new_params = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ class MCPToolManageService:
|
|||
raise
|
||||
|
||||
user = mcp_provider.load_user()
|
||||
if not mcp_provider.icon:
|
||||
raise ValueError("MCP provider icon is required")
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
|
|
|
|||
|
|
@ -152,7 +152,8 @@ class ToolTransformService:
|
|||
|
||||
if decrypt_credentials:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
if not db_provider.tenant_id:
|
||||
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
|
||||
# init tool configuration
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
|
|
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
with pytest.raises(AccountNotInitializedError):
|
||||
protected_view()
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
|
|||
return "member_added"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = add_member()
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
add_member()
|
||||
|
|
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 1: Should reject when source is datasets
|
||||
with app.test_request_context("/?source=datasets"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_document()
|
||||
|
|
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 2: Should allow when source is not datasets
|
||||
with app.test_request_context("/?source=other"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = upload_document()
|
||||
assert result == "document_uploaded"
|
||||
|
|
@ -239,7 +239,7 @@ class TestRateLimiting:
|
|||
return "knowledge_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
@ -271,7 +271,7 @@ class TestRateLimiting:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import os
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from unittest import mock
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.extractor import notion_extractor
|
||||
|
||||
user_id = "user1"
|
||||
|
|
@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
|
|||
return text.strip()
|
||||
|
||||
|
||||
def test_notion_page(mocker):
|
||||
def test_notion_page(mocker: MockerFixture):
|
||||
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
|
||||
mocked_notion_page = {
|
||||
"object": "list",
|
||||
|
|
@ -77,7 +79,7 @@ def test_notion_page(mocker):
|
|||
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
|
||||
|
||||
def test_notion_database(mocker):
|
||||
def test_notion_database(mocker: MockerFixture):
|
||||
page_title_list = ["page1", "page2", "page3"]
|
||||
mocked_notion_database = {
|
||||
"object": "list",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
import redis
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||
from core.model_manager import LBModelManager
|
||||
|
|
@ -39,7 +40,7 @@ def lb_model_manager():
|
|||
return lb_model_manager
|
||||
|
||||
|
||||
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
|
||||
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
|
||||
# initialize redis client
|
||||
redis_client.initialize(redis.Redis())
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,13 @@ from core.entities.provider_entities import (
|
|||
)
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
CredentialFormSchema,
|
||||
FormOption,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
|
|
@ -306,3 +312,174 @@ class TestProviderConfiguration:
|
|||
|
||||
# Assert
|
||||
assert credentials == {"openai_api_key": "test_key"}
|
||||
|
||||
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
|
||||
"""Test extracting secret variables from credential form schemas"""
|
||||
# Arrange
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="secret_token",
|
||||
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 2
|
||||
assert "api_key" in secret_variables
|
||||
assert "secret_token" in secret_variables
|
||||
assert "model_name" not in secret_variables
|
||||
|
||||
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
|
||||
"""Test extracting secret variables when no secret input fields exist"""
|
||||
# Arrange
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.SELECT,
|
||||
required=True,
|
||||
options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 0
|
||||
|
||||
def test_extract_secret_variables_empty_list(self, provider_configuration):
|
||||
"""Test extracting secret variables from empty credential form schemas"""
|
||||
# Arrange
|
||||
credential_form_schemas = []
|
||||
|
||||
# Act
|
||||
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert len(secret_variables) == 0
|
||||
|
||||
@patch("core.entities.provider_configuration.encrypter")
|
||||
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
|
||||
"""Test obfuscating credentials with secret variables"""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"api_key": "sk-1234567890abcdef",
|
||||
"model_name": "gpt-4",
|
||||
"secret_token": "secret_value_123",
|
||||
"temperature": "0.7",
|
||||
}
|
||||
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="api_key",
|
||||
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="secret_token",
|
||||
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
|
||||
type=FormType.SECRET_INPUT,
|
||||
required=False,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated["api_key"] == "***cdef"
|
||||
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
|
||||
assert obfuscated["secret_token"] == "***_123"
|
||||
assert obfuscated["temperature"] == "0.7" # Not obfuscated
|
||||
|
||||
# Verify encrypter was called for secret fields only
|
||||
assert mock_encrypter.obfuscated_token.call_count == 2
|
||||
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
|
||||
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
|
||||
|
||||
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
|
||||
"""Test obfuscating credentials when no secret variables exist"""
|
||||
# Arrange
|
||||
credentials = {
|
||||
"model_name": "gpt-4",
|
||||
"temperature": "0.7",
|
||||
"max_tokens": "1000",
|
||||
}
|
||||
|
||||
credential_form_schemas = [
|
||||
CredentialFormSchema(
|
||||
variable="model_name",
|
||||
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
CredentialFormSchema(
|
||||
variable="max_tokens",
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
|
||||
type=FormType.TEXT_INPUT,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated == credentials # No changes expected
|
||||
|
||||
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
|
||||
"""Test obfuscating empty credentials"""
|
||||
# Arrange
|
||||
credentials = {}
|
||||
credential_form_schemas = []
|
||||
|
||||
# Act
|
||||
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
|
||||
|
||||
# Assert
|
||||
assert obfuscated == {}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker):
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
mock_entity.model_credential_schema = mocker.Mock()
|
||||
mock_entity.model_credential_schema.credential_form_schemas = []
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker, mock_provider_entity):
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
|
|||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
|
|||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
|
|
|
|||
|
|
@ -1520,7 +1520,7 @@ requires-dist = [
|
|||
{ name = "opentelemetry-sdk", specifier = "==1.27.0" },
|
||||
{ name = "opentelemetry-semantic-conventions", specifier = "==0.48b0" },
|
||||
{ name = "opentelemetry-util-http", specifier = "==0.48b0" },
|
||||
{ name = "opik", specifier = "~=1.7.25" },
|
||||
{ name = "opik", specifier = "~=1.8.72" },
|
||||
{ name = "packaging", specifier = "~=23.2" },
|
||||
{ name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" },
|
||||
{ name = "psycogreen", specifier = "~=1.0.2" },
|
||||
|
|
@ -4019,7 +4019,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "opik"
|
||||
version = "1.7.43"
|
||||
version = "1.8.72"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "boto3-stubs", extra = ["bedrock-runtime"] },
|
||||
|
|
@ -4038,9 +4038,9 @@ dependencies = [
|
|||
{ name = "tqdm" },
|
||||
{ name = "uuid6" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ba/52/cea0317bc3207bc967b48932781995d9cdb2c490e7e05caa00ff660f7205/opik-1.7.43.tar.gz", hash = "sha256:0b02522b0b74d0a67b141939deda01f8bb69690eda6b04a7cecb1c7f0649ccd0", size = 326886, upload-time = "2025-07-07T10:30:07.715Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/aa/08/679b60db21994cf3318d4cdd1d08417c1877b79ac20971a8d80f118c9455/opik-1.8.72.tar.gz", hash = "sha256:26fcb003dc609d96b52eaf6a12fb16eb2b69eb0d1b35d88279ec612925d23944", size = 409774, upload-time = "2025-10-10T13:22:38.2Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/ae/f3566bdc3c49a1a8f795b1b6e726ef211c87e31f92d870ca6d63999c9bbf/opik-1.7.43-py3-none-any.whl", hash = "sha256:a66395c8b5ea7c24846f72dafc70c74d5b8f24ffbc4c8a1b3a7f9456e550568d", size = 625356, upload-time = "2025-07-07T10:30:06.389Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/f5/04d35af828d127de65a36286ce5b53e7310087a6b55a56f398daa7f0c9a6/opik-1.8.72-py3-none-any.whl", hash = "sha256:697e361a8364666f36aeb197aaba7ffa0696b49f04d2257b733d436749c90a8c", size = 768233, upload-time = "2025-10-10T13:22:36.352Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -1 +1,3 @@
|
|||
recursive-include dify_client *.py
|
||||
include README.md
|
||||
include LICENSE
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ First, install `dify-client` python sdk package:
|
|||
pip install dify-client
|
||||
```
|
||||
|
||||
### Synchronous Usage
|
||||
|
||||
Write your code with sdk:
|
||||
|
||||
- completion generate with `blocking` response_mode
|
||||
|
|
@ -221,3 +223,187 @@ answer = result.get("data").get("outputs")
|
|||
print(answer["answer"])
|
||||
|
||||
```
|
||||
|
||||
- Dataset Management
|
||||
|
||||
```python
|
||||
from dify_client import KnowledgeBaseClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
dataset_id = "your_dataset_id"
|
||||
|
||||
# Use context manager to ensure proper resource cleanup
|
||||
with KnowledgeBaseClient(api_key, dataset_id) as kb_client:
|
||||
# Get dataset information
|
||||
dataset_info = kb_client.get_dataset()
|
||||
dataset_info.raise_for_status()
|
||||
print(dataset_info.json())
|
||||
|
||||
# Update dataset configuration
|
||||
update_response = kb_client.update_dataset(
|
||||
name="Updated Dataset Name",
|
||||
description="Updated description",
|
||||
indexing_technique="high_quality"
|
||||
)
|
||||
update_response.raise_for_status()
|
||||
print(update_response.json())
|
||||
|
||||
# Batch update document status
|
||||
batch_response = kb_client.batch_update_document_status(
|
||||
action="enable",
|
||||
document_ids=["doc_id_1", "doc_id_2", "doc_id_3"]
|
||||
)
|
||||
batch_response.raise_for_status()
|
||||
print(batch_response.json())
|
||||
```
|
||||
|
||||
- Conversation Variables Management
|
||||
|
||||
```python
|
||||
from dify_client import ChatClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
|
||||
# Use context manager to ensure proper resource cleanup
|
||||
with ChatClient(api_key) as chat_client:
|
||||
# Get all conversation variables
|
||||
variables = chat_client.get_conversation_variables(
|
||||
conversation_id="conversation_id",
|
||||
user="user_id"
|
||||
)
|
||||
variables.raise_for_status()
|
||||
print(variables.json())
|
||||
|
||||
# Update a specific conversation variable
|
||||
update_var = chat_client.update_conversation_variable(
|
||||
conversation_id="conversation_id",
|
||||
variable_id="variable_id",
|
||||
value="new_value",
|
||||
user="user_id"
|
||||
)
|
||||
update_var.raise_for_status()
|
||||
print(update_var.json())
|
||||
```
|
||||
|
||||
### Asynchronous Usage
|
||||
|
||||
The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls.
|
||||
|
||||
- async chat with `blocking` response_mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from dify_client import AsyncChatClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
|
||||
async def main():
|
||||
# Use async context manager for proper resource cleanup
|
||||
async with AsyncChatClient(api_key) as client:
|
||||
response = await client.create_chat_message(
|
||||
inputs={},
|
||||
query="Hello, how are you?",
|
||||
user="user_id",
|
||||
response_mode="blocking"
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
print(result.get('answer'))
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
- async completion with `streaming` response_mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import json
|
||||
from dify_client import AsyncCompletionClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
|
||||
async def main():
|
||||
async with AsyncCompletionClient(api_key) as client:
|
||||
response = await client.create_completion_message(
|
||||
inputs={"query": "What's the weather?"},
|
||||
response_mode="streaming",
|
||||
user="user_id"
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Stream the response
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith('data:'):
|
||||
data = line[5:].strip()
|
||||
if data:
|
||||
chunk = json.loads(data)
|
||||
print(chunk.get('answer', ''), end='', flush=True)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
- async workflow execution
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from dify_client import AsyncWorkflowClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
|
||||
async def main():
|
||||
async with AsyncWorkflowClient(api_key) as client:
|
||||
response = await client.run(
|
||||
inputs={"query": "What is machine learning?"},
|
||||
response_mode="blocking",
|
||||
user="user_id"
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
print(result.get("data").get("outputs"))
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
- async dataset management
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from dify_client import AsyncKnowledgeBaseClient
|
||||
|
||||
api_key = "your_api_key"
|
||||
dataset_id = "your_dataset_id"
|
||||
|
||||
async def main():
|
||||
async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client:
|
||||
# Get dataset information
|
||||
dataset_info = await kb_client.get_dataset()
|
||||
dataset_info.raise_for_status()
|
||||
print(dataset_info.json())
|
||||
|
||||
# List documents
|
||||
docs = await kb_client.list_documents(page=1, page_size=10)
|
||||
docs.raise_for_status()
|
||||
print(docs.json())
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
**Benefits of Async Usage:**
|
||||
|
||||
- **Better Performance**: Handle multiple concurrent API requests efficiently
|
||||
- **Non-blocking I/O**: Don't block the event loop during network operations
|
||||
- **Scalability**: Ideal for applications handling many simultaneous requests
|
||||
- **Modern Python**: Leverages Python's native async/await syntax
|
||||
|
||||
**Available Async Clients:**
|
||||
|
||||
- `AsyncDifyClient` - Base async client
|
||||
- `AsyncChatClient` - Async chat operations
|
||||
- `AsyncCompletionClient` - Async completion operations
|
||||
- `AsyncWorkflowClient` - Async workflow operations
|
||||
- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations
|
||||
- `AsyncWorkspaceClient` - Async workspace operations
|
||||
|
||||
```
|
||||
```
|
||||
|
|
|
|||
|
|
@ -7,11 +7,28 @@ from dify_client.client import (
|
|||
WorkspaceClient,
|
||||
)
|
||||
|
||||
from dify_client.async_client import (
|
||||
AsyncChatClient,
|
||||
AsyncCompletionClient,
|
||||
AsyncDifyClient,
|
||||
AsyncKnowledgeBaseClient,
|
||||
AsyncWorkflowClient,
|
||||
AsyncWorkspaceClient,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Synchronous clients
|
||||
"ChatClient",
|
||||
"CompletionClient",
|
||||
"DifyClient",
|
||||
"KnowledgeBaseClient",
|
||||
"WorkflowClient",
|
||||
"WorkspaceClient",
|
||||
# Asynchronous clients
|
||||
"AsyncChatClient",
|
||||
"AsyncCompletionClient",
|
||||
"AsyncDifyClient",
|
||||
"AsyncKnowledgeBaseClient",
|
||||
"AsyncWorkflowClient",
|
||||
"AsyncWorkspaceClient",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,808 @@
|
|||
"""Asynchronous Dify API client.
|
||||
|
||||
This module provides async/await support for all Dify API operations using httpx.AsyncClient.
|
||||
All client classes mirror their synchronous counterparts but require `await` for method calls.
|
||||
|
||||
Example:
|
||||
import asyncio
|
||||
from dify_client import AsyncChatClient
|
||||
|
||||
async def main():
|
||||
async with AsyncChatClient(api_key="your-key") as client:
|
||||
response = await client.create_chat_message(
|
||||
inputs={},
|
||||
query="Hello",
|
||||
user="user-123"
|
||||
)
|
||||
print(response.json())
|
||||
|
||||
asyncio.run(main())
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Literal, Dict, List, Any, IO
|
||||
|
||||
import aiofiles
|
||||
import httpx
|
||||
|
||||
|
||||
class AsyncDifyClient:
|
||||
"""Asynchronous Dify API client.
|
||||
|
||||
This client uses httpx.AsyncClient for efficient async connection pooling.
|
||||
It's recommended to use this client as a context manager:
|
||||
|
||||
Example:
|
||||
async with AsyncDifyClient(api_key="your-key") as client:
|
||||
response = await client.get_app_info()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.dify.ai/v1",
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
"""Initialize the async Dify client.
|
||||
|
||||
Args:
|
||||
api_key: Your Dify API key
|
||||
base_url: Base URL for the Dify API
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=base_url,
|
||||
timeout=httpx.Timeout(timeout, connect=5.0),
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Support async context manager protocol."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Clean up resources when exiting async context."""
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self):
|
||||
"""Close the async HTTP client and release resources."""
|
||||
if hasattr(self, "_client"):
|
||||
await self._client.aclose()
|
||||
|
||||
async def _send_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: dict | None = None,
|
||||
params: dict | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Send an async HTTP request to the Dify API.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, PATCH, DELETE)
|
||||
endpoint: API endpoint path
|
||||
json: JSON request body
|
||||
params: Query parameters
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional arguments to pass to httpx.request
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = await self._client.request(
|
||||
method,
|
||||
endpoint,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict):
|
||||
"""Send an async HTTP request with file uploads.
|
||||
|
||||
Args:
|
||||
method: HTTP method (POST, PUT, etc.)
|
||||
endpoint: API endpoint path
|
||||
data: Form data
|
||||
files: Files to upload
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
response = await self._client.request(
|
||||
method,
|
||||
endpoint,
|
||||
data=data,
|
||||
headers=headers,
|
||||
files=files,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str):
|
||||
"""Send feedback for a message."""
|
||||
data = {"rating": rating, "user": user}
|
||||
return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data)
|
||||
|
||||
async def get_application_parameters(self, user: str):
|
||||
"""Get application parameters."""
|
||||
params = {"user": user}
|
||||
return await self._send_request("GET", "/parameters", params=params)
|
||||
|
||||
async def file_upload(self, user: str, files: dict):
|
||||
"""Upload a file."""
|
||||
data = {"user": user}
|
||||
return await self._send_request_with_files("POST", "/files/upload", data=data, files=files)
|
||||
|
||||
async def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
||||
"""Convert text to audio."""
|
||||
data = {"text": text, "user": user, "streaming": streaming}
|
||||
return await self._send_request("POST", "/text-to-audio", json=data)
|
||||
|
||||
async def get_meta(self, user: str):
|
||||
"""Get metadata."""
|
||||
params = {"user": user}
|
||||
return await self._send_request("GET", "/meta", params=params)
|
||||
|
||||
async def get_app_info(self):
|
||||
"""Get basic application information including name, description, tags, and mode."""
|
||||
return await self._send_request("GET", "/info")
|
||||
|
||||
async def get_app_site_info(self):
|
||||
"""Get application site information."""
|
||||
return await self._send_request("GET", "/site")
|
||||
|
||||
async def get_file_preview(self, file_id: str):
|
||||
"""Get file preview by file ID."""
|
||||
return await self._send_request("GET", f"/files/{file_id}/preview")
|
||||
|
||||
|
||||
class AsyncCompletionClient(AsyncDifyClient):
|
||||
"""Async client for Completion API operations."""
|
||||
|
||||
async def create_completion_message(
|
||||
self,
|
||||
inputs: dict,
|
||||
response_mode: Literal["blocking", "streaming"],
|
||||
user: str,
|
||||
files: dict | None = None,
|
||||
):
|
||||
"""Create a completion message.
|
||||
|
||||
Args:
|
||||
inputs: Input variables for the completion
|
||||
response_mode: Response mode ('blocking' or 'streaming')
|
||||
user: User identifier
|
||||
files: Optional files to include
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"response_mode": response_mode,
|
||||
"user": user,
|
||||
"files": files,
|
||||
}
|
||||
return await self._send_request(
|
||||
"POST",
|
||||
"/completion-messages",
|
||||
data,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
|
||||
class AsyncChatClient(AsyncDifyClient):
|
||||
"""Async client for Chat API operations."""
|
||||
|
||||
async def create_chat_message(
|
||||
self,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
user: str,
|
||||
response_mode: Literal["blocking", "streaming"] = "blocking",
|
||||
conversation_id: str | None = None,
|
||||
files: dict | None = None,
|
||||
):
|
||||
"""Create a chat message.
|
||||
|
||||
Args:
|
||||
inputs: Input variables for the chat
|
||||
query: User query/message
|
||||
user: User identifier
|
||||
response_mode: Response mode ('blocking' or 'streaming')
|
||||
conversation_id: Optional conversation ID for context
|
||||
files: Optional files to include
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"user": user,
|
||||
"response_mode": response_mode,
|
||||
"files": files,
|
||||
}
|
||||
if conversation_id:
|
||||
data["conversation_id"] = conversation_id
|
||||
|
||||
return await self._send_request(
|
||||
"POST",
|
||||
"/chat-messages",
|
||||
data,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
async def get_suggested(self, message_id: str, user: str):
|
||||
"""Get suggested questions for a message."""
|
||||
params = {"user": user}
|
||||
return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
|
||||
|
||||
async def stop_message(self, task_id: str, user: str):
|
||||
"""Stop a running message generation."""
|
||||
data = {"user": user}
|
||||
return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data)
|
||||
|
||||
async def get_conversations(
|
||||
self,
|
||||
user: str,
|
||||
last_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
pinned: bool | None = None,
|
||||
):
|
||||
"""Get list of conversations."""
|
||||
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
|
||||
return await self._send_request("GET", "/conversations", params=params)
|
||||
|
||||
async def get_conversation_messages(
|
||||
self,
|
||||
user: str,
|
||||
conversation_id: str | None = None,
|
||||
first_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
):
|
||||
"""Get messages from a conversation."""
|
||||
params = {
|
||||
"user": user,
|
||||
"conversation_id": conversation_id,
|
||||
"first_id": first_id,
|
||||
"limit": limit,
|
||||
}
|
||||
return await self._send_request("GET", "/messages", params=params)
|
||||
|
||||
async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
|
||||
"""Rename a conversation."""
|
||||
data = {"name": name, "auto_generate": auto_generate, "user": user}
|
||||
return await self._send_request("POST", f"/conversations/{conversation_id}/name", data)
|
||||
|
||||
async def delete_conversation(self, conversation_id: str, user: str):
|
||||
"""Delete a conversation."""
|
||||
data = {"user": user}
|
||||
return await self._send_request("DELETE", f"/conversations/{conversation_id}", data)
|
||||
|
||||
async def audio_to_text(self, audio_file: IO[bytes] | tuple, user: str):
|
||||
"""Convert audio to text."""
|
||||
data = {"user": user}
|
||||
files = {"file": audio_file}
|
||||
return await self._send_request_with_files("POST", "/audio-to-text", data, files)
|
||||
|
||||
# Annotation APIs
|
||||
async def annotation_reply_action(
|
||||
self,
|
||||
action: Literal["enable", "disable"],
|
||||
score_threshold: float,
|
||||
embedding_provider_name: str,
|
||||
embedding_model_name: str,
|
||||
):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
data = {
|
||||
"score_threshold": score_threshold,
|
||||
"embedding_provider_name": embedding_provider_name,
|
||||
"embedding_model_name": embedding_model_name,
|
||||
}
|
||||
return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data)
|
||||
|
||||
async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str):
|
||||
"""Get the status of an annotation reply action job."""
|
||||
return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}")
|
||||
|
||||
async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None):
|
||||
"""List annotations for the application."""
|
||||
params = {"page": page, "limit": limit, "keyword": keyword}
|
||||
return await self._send_request("GET", "/apps/annotations", params=params)
|
||||
|
||||
async def create_annotation(self, question: str, answer: str):
|
||||
"""Create a new annotation."""
|
||||
data = {"question": question, "answer": answer}
|
||||
return await self._send_request("POST", "/apps/annotations", json=data)
|
||||
|
||||
async def update_annotation(self, annotation_id: str, question: str, answer: str):
|
||||
"""Update an existing annotation."""
|
||||
data = {"question": question, "answer": answer}
|
||||
return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data)
|
||||
|
||||
async def delete_annotation(self, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}")
|
||||
|
||||
# Conversation Variables APIs
|
||||
async def get_conversation_variables(self, conversation_id: str, user: str):
|
||||
"""Get all variables for a specific conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to query variables for
|
||||
user: User identifier
|
||||
|
||||
Returns:
|
||||
Response from the API containing:
|
||||
- variables: List of conversation variables with their values
|
||||
- conversation_id: The conversation ID
|
||||
"""
|
||||
params = {"user": user}
|
||||
url = f"/conversations/{conversation_id}/variables"
|
||||
return await self._send_request("GET", url, params=params)
|
||||
|
||||
async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str):
|
||||
"""Update a specific conversation variable.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID
|
||||
variable_id: The variable ID to update
|
||||
value: New value for the variable
|
||||
user: User identifier
|
||||
|
||||
Returns:
|
||||
Response from the API with updated variable information
|
||||
"""
|
||||
data = {"value": value, "user": user}
|
||||
url = f"/conversations/{conversation_id}/variables/{variable_id}"
|
||||
return await self._send_request("PATCH", url, json=data)
|
||||
|
||||
|
||||
class AsyncWorkflowClient(AsyncDifyClient):
|
||||
"""Async client for Workflow API operations."""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
inputs: dict,
|
||||
response_mode: Literal["blocking", "streaming"] = "streaming",
|
||||
user: str = "abc-123",
|
||||
):
|
||||
"""Run a workflow."""
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return await self._send_request("POST", "/workflows/run", data)
|
||||
|
||||
async def stop(self, task_id: str, user: str):
|
||||
"""Stop a running workflow task."""
|
||||
data = {"user": user}
|
||||
return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
|
||||
|
||||
async def get_result(self, workflow_run_id: str):
|
||||
"""Get workflow run result."""
|
||||
return await self._send_request("GET", f"/workflows/run/{workflow_run_id}")
|
||||
|
||||
async def get_workflow_logs(
|
||||
self,
|
||||
keyword: str = None,
|
||||
status: Literal["succeeded", "failed", "stopped"] | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
created_at__before: str = None,
|
||||
created_at__after: str = None,
|
||||
created_by_end_user_session_id: str = None,
|
||||
created_by_account: str = None,
|
||||
):
|
||||
"""Get workflow execution logs with optional filtering."""
|
||||
params = {
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"keyword": keyword,
|
||||
"status": status,
|
||||
"created_at__before": created_at__before,
|
||||
"created_at__after": created_at__after,
|
||||
"created_by_end_user_session_id": created_by_end_user_session_id,
|
||||
"created_by_account": created_by_account,
|
||||
}
|
||||
return await self._send_request("GET", "/workflows/logs", params=params)
|
||||
|
||||
async def run_specific_workflow(
|
||||
self,
|
||||
workflow_id: str,
|
||||
inputs: dict,
|
||||
response_mode: Literal["blocking", "streaming"] = "streaming",
|
||||
user: str = "abc-123",
|
||||
):
|
||||
"""Run a specific workflow by workflow ID."""
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return await self._send_request(
|
||||
"POST",
|
||||
f"/workflows/{workflow_id}/run",
|
||||
data,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkspaceClient(AsyncDifyClient):
|
||||
"""Async client for workspace-related operations."""
|
||||
|
||||
async def get_available_models(self, model_type: str):
|
||||
"""Get available models by model type."""
|
||||
url = f"/workspaces/current/models/model-types/{model_type}"
|
||||
return await self._send_request("GET", url)
|
||||
|
||||
|
||||
class AsyncKnowledgeBaseClient(AsyncDifyClient):
|
||||
"""Async client for Knowledge Base API operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.dify.ai/v1",
|
||||
dataset_id: str | None = None,
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
"""Construct an AsyncKnowledgeBaseClient object.
|
||||
|
||||
Args:
|
||||
api_key: API key of Dify
|
||||
base_url: Base URL of Dify API
|
||||
dataset_id: ID of the dataset
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
||||
self.dataset_id = dataset_id
|
||||
|
||||
def _get_dataset_id(self):
|
||||
"""Get the dataset ID, raise error if not set."""
|
||||
if self.dataset_id is None:
|
||||
raise ValueError("dataset_id is not set")
|
||||
return self.dataset_id
|
||||
|
||||
async def create_dataset(self, name: str, **kwargs):
|
||||
"""Create a new dataset."""
|
||||
return await self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
||||
|
||||
async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
||||
"""List all datasets."""
|
||||
return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
|
||||
|
||||
async def create_document_by_text(self, name: str, text: str, extra_params: dict | None = None, **kwargs):
|
||||
"""Create a document by text.
|
||||
|
||||
Args:
|
||||
name: Name of the document
|
||||
text: Text content of the document
|
||||
extra_params: Extra parameters for the API
|
||||
|
||||
Returns:
|
||||
Response from the API
|
||||
"""
|
||||
data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"name": name,
|
||||
"text": text,
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
|
||||
return await self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
async def update_document_by_text(
|
||||
self,
|
||||
document_id: str,
|
||||
name: str,
|
||||
text: str,
|
||||
extra_params: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Update a document by text."""
|
||||
data = {"name": name, "text": text}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
||||
return await self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
async def create_document_by_file(
|
||||
self,
|
||||
file_path: str,
|
||||
original_document_id: str | None = None,
|
||||
extra_params: dict | None = None,
|
||||
):
|
||||
"""Create a document by file."""
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
data = {
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"indexing_technique": "high_quality",
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
if original_document_id is not None:
|
||||
data["original_document_id"] = original_document_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||
return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
async def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
|
||||
"""Update a document by file."""
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
data = {}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
async def batch_indexing_status(self, batch_id: str, **kwargs):
|
||||
"""Get the status of the batch indexing."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
|
||||
return await self._send_request("GET", url, **kwargs)
|
||||
|
||||
async def delete_dataset(self):
|
||||
"""Delete this dataset."""
|
||||
url = f"/datasets/{self._get_dataset_id()}"
|
||||
return await self._send_request("DELETE", url)
|
||||
|
||||
async def delete_document(self, document_id: str):
|
||||
"""Delete a document."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
|
||||
return await self._send_request("DELETE", url)
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
keyword: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get a list of documents in this dataset."""
|
||||
params = {
|
||||
"page": page,
|
||||
"limit": page_size,
|
||||
"keyword": keyword,
|
||||
}
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents"
|
||||
return await self._send_request("GET", url, params=params, **kwargs)
|
||||
|
||||
async def add_segments(self, document_id: str, segments: list[dict], **kwargs):
|
||||
"""Add segments to a document."""
|
||||
data = {"segments": segments}
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
|
||||
return await self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
async def query_segments(
|
||||
self,
|
||||
document_id: str,
|
||||
keyword: str | None = None,
|
||||
status: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Query segments in this document.
|
||||
|
||||
Args:
|
||||
document_id: ID of the document
|
||||
keyword: Query keyword (optional)
|
||||
status: Status of the segment (optional, e.g., 'completed')
|
||||
**kwargs: Additional parameters to pass to the API.
|
||||
Can include a 'params' dict for extra query parameters.
|
||||
|
||||
Returns:
|
||||
Response from the API
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
|
||||
params = {
|
||||
"keyword": keyword,
|
||||
"status": status,
|
||||
}
|
||||
if "params" in kwargs:
|
||||
params.update(kwargs.pop("params"))
|
||||
return await self._send_request("GET", url, params=params, **kwargs)
|
||||
|
||||
async def delete_document_segment(self, document_id: str, segment_id: str):
|
||||
"""Delete a segment from a document."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
|
||||
return await self._send_request("DELETE", url)
|
||||
|
||||
async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs):
|
||||
"""Update a segment in a document."""
|
||||
data = {"segment": segment_data}
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
|
||||
return await self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
# Advanced Knowledge Base APIs
|
||||
async def hit_testing(
|
||||
self,
|
||||
query: str,
|
||||
retrieval_model: Dict[str, Any] = None,
|
||||
external_retrieval_model: Dict[str, Any] = None,
|
||||
):
|
||||
"""Perform hit testing on the dataset."""
|
||||
data = {"query": query}
|
||||
if retrieval_model:
|
||||
data["retrieval_model"] = retrieval_model
|
||||
if external_retrieval_model:
|
||||
data["external_retrieval_model"] = external_retrieval_model
|
||||
url = f"/datasets/{self._get_dataset_id()}/hit-testing"
|
||||
return await self._send_request("POST", url, json=data)
|
||||
|
||||
async def get_dataset_metadata(self):
|
||||
"""Get dataset metadata."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/metadata"
|
||||
return await self._send_request("GET", url)
|
||||
|
||||
async def create_dataset_metadata(self, metadata_data: Dict[str, Any]):
|
||||
"""Create dataset metadata."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/metadata"
|
||||
return await self._send_request("POST", url, json=metadata_data)
|
||||
|
||||
async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]):
|
||||
"""Update dataset metadata."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}"
|
||||
return await self._send_request("PATCH", url, json=metadata_data)
|
||||
|
||||
async def get_built_in_metadata(self):
|
||||
"""Get built-in metadata."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/metadata/built-in"
|
||||
return await self._send_request("GET", url)
|
||||
|
||||
async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] = None):
|
||||
"""Manage built-in metadata with specified action."""
|
||||
data = metadata_data or {}
|
||||
url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}"
|
||||
return await self._send_request("POST", url, json=data)
|
||||
|
||||
async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]):
|
||||
"""Update metadata for multiple documents."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/metadata"
|
||||
data = {"operation_data": operation_data}
|
||||
return await self._send_request("POST", url, json=data)
|
||||
|
||||
# Dataset Tags APIs
|
||||
async def list_dataset_tags(self):
|
||||
"""List all dataset tags."""
|
||||
return await self._send_request("GET", "/datasets/tags")
|
||||
|
||||
async def bind_dataset_tags(self, tag_ids: List[str]):
|
||||
"""Bind tags to dataset."""
|
||||
data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()}
|
||||
return await self._send_request("POST", "/datasets/tags/binding", json=data)
|
||||
|
||||
async def unbind_dataset_tag(self, tag_id: str):
|
||||
"""Unbind a single tag from dataset."""
|
||||
data = {"tag_id": tag_id, "target_id": self._get_dataset_id()}
|
||||
return await self._send_request("POST", "/datasets/tags/unbinding", json=data)
|
||||
|
||||
async def get_dataset_tags(self):
|
||||
"""Get tags for current dataset."""
|
||||
url = f"/datasets/{self._get_dataset_id()}/tags"
|
||||
return await self._send_request("GET", url)
|
||||
|
||||
# RAG Pipeline APIs
|
||||
async def get_datasource_plugins(self, is_published: bool = True):
|
||||
"""Get datasource plugins for RAG pipeline."""
|
||||
params = {"is_published": is_published}
|
||||
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins"
|
||||
return await self._send_request("GET", url, params=params)
|
||||
|
||||
async def run_datasource_node(
|
||||
self,
|
||||
node_id: str,
|
||||
inputs: Dict[str, Any],
|
||||
datasource_type: str,
|
||||
is_published: bool = True,
|
||||
credential_id: str = None,
|
||||
):
|
||||
"""Run a datasource node in RAG pipeline."""
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"datasource_type": datasource_type,
|
||||
"is_published": is_published,
|
||||
}
|
||||
if credential_id:
|
||||
data["credential_id"] = credential_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run"
|
||||
return await self._send_request("POST", url, json=data, stream=True)
|
||||
|
||||
async def run_rag_pipeline(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
datasource_type: str,
|
||||
datasource_info_list: List[Dict[str, Any]],
|
||||
start_node_id: str,
|
||||
is_published: bool = True,
|
||||
response_mode: Literal["streaming", "blocking"] = "blocking",
|
||||
):
|
||||
"""Run RAG pipeline."""
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"datasource_type": datasource_type,
|
||||
"datasource_info_list": datasource_info_list,
|
||||
"start_node_id": start_node_id,
|
||||
"is_published": is_published,
|
||||
"response_mode": response_mode,
|
||||
}
|
||||
url = f"/datasets/{self._get_dataset_id()}/pipeline/run"
|
||||
return await self._send_request("POST", url, json=data, stream=response_mode == "streaming")
|
||||
|
||||
async def upload_pipeline_file(self, file_path: str):
|
||||
"""Upload file for RAG pipeline."""
|
||||
async with aiofiles.open(file_path, "rb") as f:
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files)
|
||||
|
||||
# Dataset Management APIs
|
||||
async def get_dataset(self, dataset_id: str | None = None):
|
||||
"""Get detailed information about a specific dataset."""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}"
|
||||
return await self._send_request("GET", url)
|
||||
|
||||
async def update_dataset(
|
||||
self,
|
||||
dataset_id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
indexing_technique: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
retrieval_model: Dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Update dataset configuration.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
|
||||
name: New dataset name
|
||||
description: New dataset description
|
||||
indexing_technique: Indexing technique ('high_quality' or 'economy')
|
||||
embedding_model: Embedding model name
|
||||
embedding_model_provider: Embedding model provider
|
||||
retrieval_model: Retrieval model configuration dict
|
||||
**kwargs: Additional parameters to pass to the API
|
||||
|
||||
Returns:
|
||||
Response from the API with updated dataset information
|
||||
"""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}"
|
||||
|
||||
payload = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"indexing_technique": indexing_technique,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_model_provider": embedding_model_provider,
|
||||
"retrieval_model": retrieval_model,
|
||||
}
|
||||
|
||||
data = {k: v for k, v in payload.items() if v is not None}
|
||||
data.update(kwargs)
|
||||
|
||||
return await self._send_request("PATCH", url, json=data)
|
||||
|
||||
async def batch_update_document_status(
|
||||
self,
|
||||
action: Literal["enable", "disable", "archive", "un_archive"],
|
||||
document_ids: List[str],
|
||||
dataset_id: str | None = None,
|
||||
):
|
||||
"""Batch update document status."""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}/documents/status/{action}"
|
||||
data = {"document_ids": document_ids}
|
||||
return await self._send_request("PATCH", url, json=data)
|
||||
|
|
@ -1,32 +1,114 @@
|
|||
import json
|
||||
from typing import Literal, Union, Dict, List, Any, Optional, IO
|
||||
import os
|
||||
from typing import Literal, Dict, List, Any, IO
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
|
||||
class DifyClient:
|
||||
def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"):
|
||||
"""Synchronous Dify API client.
|
||||
|
||||
This client uses httpx.Client for efficient connection pooling and resource management.
|
||||
It's recommended to use this client as a context manager:
|
||||
|
||||
Example:
|
||||
with DifyClient(api_key="your-key") as client:
|
||||
response = client.get_app_info()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.dify.ai/v1",
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
"""Initialize the Dify client.
|
||||
|
||||
Args:
|
||||
api_key: Your Dify API key
|
||||
base_url: Base URL for the Dify API
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self._client = httpx.Client(
|
||||
base_url=base_url,
|
||||
timeout=httpx.Timeout(timeout, connect=5.0),
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""Support context manager protocol."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Clean up resources when exiting context."""
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Close the HTTP client and release resources."""
|
||||
if hasattr(self, "_client"):
|
||||
self._client.close()
|
||||
|
||||
def _send_request(
|
||||
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
json: dict | None = None,
|
||||
params: dict | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Send an HTTP request to the Dify API.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, PATCH, DELETE)
|
||||
endpoint: API endpoint path
|
||||
json: JSON request body
|
||||
params: Query parameters
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional arguments to pass to httpx.request
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
|
||||
# httpx.Client automatically prepends base_url
|
||||
response = self._client.request(
|
||||
method,
|
||||
endpoint,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _send_request_with_files(self, method, endpoint, data, files):
|
||||
def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict):
|
||||
"""Send an HTTP request with file uploads.
|
||||
|
||||
Args:
|
||||
method: HTTP method (POST, PUT, etc.)
|
||||
endpoint: API endpoint path
|
||||
data: Form data
|
||||
files: Files to upload
|
||||
|
||||
Returns:
|
||||
httpx.Response object
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(method, url, data=data, headers=headers, files=files)
|
||||
response = self._client.request(
|
||||
method,
|
||||
endpoint,
|
||||
data=data,
|
||||
headers=headers,
|
||||
files=files,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -65,7 +147,11 @@ class DifyClient:
|
|||
|
||||
class CompletionClient(DifyClient):
|
||||
def create_completion_message(
|
||||
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
|
||||
self,
|
||||
inputs: dict,
|
||||
response_mode: Literal["blocking", "streaming"],
|
||||
user: str,
|
||||
files: dict | None = None,
|
||||
):
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
|
|
@ -77,7 +163,7 @@ class CompletionClient(DifyClient):
|
|||
"POST",
|
||||
"/completion-messages",
|
||||
data,
|
||||
stream=True if response_mode == "streaming" else False,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -105,7 +191,7 @@ class ChatClient(DifyClient):
|
|||
"POST",
|
||||
"/chat-messages",
|
||||
data,
|
||||
stream=True if response_mode == "streaming" else False,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
def get_suggested(self, message_id: str, user: str):
|
||||
|
|
@ -166,10 +252,6 @@ class ChatClient(DifyClient):
|
|||
embedding_model_name: str,
|
||||
):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
# Backend API requires these fields to be non-None values
|
||||
if score_threshold is None or embedding_provider_name is None or embedding_model_name is None:
|
||||
raise ValueError("score_threshold, embedding_provider_name, and embedding_model_name cannot be None")
|
||||
|
||||
data = {
|
||||
"score_threshold": score_threshold,
|
||||
"embedding_provider_name": embedding_provider_name,
|
||||
|
|
@ -181,11 +263,9 @@ class ChatClient(DifyClient):
|
|||
"""Get the status of an annotation reply action job."""
|
||||
return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}")
|
||||
|
||||
def list_annotations(self, page: int = 1, limit: int = 20, keyword: str = ""):
|
||||
def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None):
|
||||
"""List annotations for the application."""
|
||||
params = {"page": page, "limit": limit}
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
params = {"page": page, "limit": limit, "keyword": keyword}
|
||||
return self._send_request("GET", "/apps/annotations", params=params)
|
||||
|
||||
def create_annotation(self, question: str, answer: str):
|
||||
|
|
@ -202,9 +282,47 @@ class ChatClient(DifyClient):
|
|||
"""Delete an annotation."""
|
||||
return self._send_request("DELETE", f"/apps/annotations/{annotation_id}")
|
||||
|
||||
# Conversation Variables APIs
|
||||
def get_conversation_variables(self, conversation_id: str, user: str):
|
||||
"""Get all variables for a specific conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to query variables for
|
||||
user: User identifier
|
||||
|
||||
Returns:
|
||||
Response from the API containing:
|
||||
- variables: List of conversation variables with their values
|
||||
- conversation_id: The conversation ID
|
||||
"""
|
||||
params = {"user": user}
|
||||
url = f"/conversations/{conversation_id}/variables"
|
||||
return self._send_request("GET", url, params=params)
|
||||
|
||||
def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str):
|
||||
"""Update a specific conversation variable.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID
|
||||
variable_id: The variable ID to update
|
||||
value: New value for the variable
|
||||
user: User identifier
|
||||
|
||||
Returns:
|
||||
Response from the API with updated variable information
|
||||
"""
|
||||
data = {"value": value, "user": user}
|
||||
url = f"/conversations/{conversation_id}/variables/{variable_id}"
|
||||
return self._send_request("PATCH", url, json=data)
|
||||
|
||||
|
||||
class WorkflowClient(DifyClient):
|
||||
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
|
||||
def run(
|
||||
self,
|
||||
inputs: dict,
|
||||
response_mode: Literal["blocking", "streaming"] = "streaming",
|
||||
user: str = "abc-123",
|
||||
):
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return self._send_request("POST", "/workflows/run", data)
|
||||
|
||||
|
|
@ -252,7 +370,10 @@ class WorkflowClient(DifyClient):
|
|||
"""Run a specific workflow by workflow ID."""
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return self._send_request(
|
||||
"POST", f"/workflows/{workflow_id}/run", data, stream=True if response_mode == "streaming" else False
|
||||
"POST",
|
||||
f"/workflows/{workflow_id}/run",
|
||||
data,
|
||||
stream=(response_mode == "streaming"),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -293,7 +414,7 @@ class KnowledgeBaseClient(DifyClient):
|
|||
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
||||
|
||||
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
||||
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
|
||||
return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
|
||||
|
||||
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
|
||||
"""
|
||||
|
|
@ -333,7 +454,12 @@ class KnowledgeBaseClient(DifyClient):
|
|||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def update_document_by_text(
|
||||
self, document_id: str, name: str, text: str, extra_params: dict | None = None, **kwargs
|
||||
self,
|
||||
document_id: str,
|
||||
name: str,
|
||||
text: str,
|
||||
extra_params: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Update a document by text.
|
||||
|
|
@ -368,7 +494,10 @@ class KnowledgeBaseClient(DifyClient):
|
|||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def create_document_by_file(
|
||||
self, file_path: str, original_document_id: str | None = None, extra_params: dict | None = None
|
||||
self,
|
||||
file_path: str,
|
||||
original_document_id: str | None = None,
|
||||
extra_params: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Create a document by file.
|
||||
|
|
@ -395,17 +524,18 @@ class KnowledgeBaseClient(DifyClient):
|
|||
}
|
||||
:return: Response from the API
|
||||
"""
|
||||
files = {"file": open(file_path, "rb")}
|
||||
data = {
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"indexing_technique": "high_quality",
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
if original_document_id is not None:
|
||||
data["original_document_id"] = original_document_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
data = {
|
||||
"process_rule": {"mode": "automatic"},
|
||||
"indexing_technique": "high_quality",
|
||||
}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
if original_document_id is not None:
|
||||
data["original_document_id"] = original_document_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
|
||||
"""
|
||||
|
|
@ -433,12 +563,13 @@ class KnowledgeBaseClient(DifyClient):
|
|||
}
|
||||
:return:
|
||||
"""
|
||||
files = {"file": open(file_path, "rb")}
|
||||
data = {}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
data = {}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
def batch_indexing_status(self, batch_id: str, **kwargs):
|
||||
"""
|
||||
|
|
@ -516,6 +647,8 @@ class KnowledgeBaseClient(DifyClient):
|
|||
:param document_id: ID of the document
|
||||
:param keyword: query keyword, optional
|
||||
:param status: status of the segment, optional, e.g. completed
|
||||
:param kwargs: Additional parameters to pass to the API.
|
||||
Can include a 'params' dict for extra query parameters.
|
||||
"""
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
|
||||
params = {}
|
||||
|
|
@ -524,7 +657,7 @@ class KnowledgeBaseClient(DifyClient):
|
|||
if status is not None:
|
||||
params["status"] = status
|
||||
if "params" in kwargs:
|
||||
params.update(kwargs["params"])
|
||||
params.update(kwargs.pop("params"))
|
||||
return self._send_request("GET", url, params=params, **kwargs)
|
||||
|
||||
def delete_document_segment(self, document_id: str, segment_id: str):
|
||||
|
|
@ -553,7 +686,10 @@ class KnowledgeBaseClient(DifyClient):
|
|||
|
||||
# Advanced Knowledge Base APIs
|
||||
def hit_testing(
|
||||
self, query: str, retrieval_model: Dict[str, Any] = None, external_retrieval_model: Dict[str, Any] = None
|
||||
self,
|
||||
query: str,
|
||||
retrieval_model: Dict[str, Any] = None,
|
||||
external_retrieval_model: Dict[str, Any] = None,
|
||||
):
|
||||
"""Perform hit testing on the dataset."""
|
||||
data = {"query": query}
|
||||
|
|
@ -632,7 +768,11 @@ class KnowledgeBaseClient(DifyClient):
|
|||
credential_id: str = None,
|
||||
):
|
||||
"""Run a datasource node in RAG pipeline."""
|
||||
data = {"inputs": inputs, "datasource_type": datasource_type, "is_published": is_published}
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"datasource_type": datasource_type,
|
||||
"is_published": is_published,
|
||||
}
|
||||
if credential_id:
|
||||
data["credential_id"] = credential_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run"
|
||||
|
|
@ -662,5 +802,94 @@ class KnowledgeBaseClient(DifyClient):
|
|||
def upload_pipeline_file(self, file_path: str):
|
||||
"""Upload file for RAG pipeline."""
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"file": f}
|
||||
files = {"file": (os.path.basename(file_path), f)}
|
||||
return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files)
|
||||
|
||||
# Dataset Management APIs
|
||||
def get_dataset(self, dataset_id: str | None = None):
|
||||
"""Get detailed information about a specific dataset.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
|
||||
|
||||
Returns:
|
||||
Response from the API containing dataset details including:
|
||||
- name, description, permission
|
||||
- indexing_technique, embedding_model, embedding_model_provider
|
||||
- retrieval_model configuration
|
||||
- document_count, word_count, app_count
|
||||
- created_at, updated_at
|
||||
"""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}"
|
||||
return self._send_request("GET", url)
|
||||
|
||||
def update_dataset(
|
||||
self,
|
||||
dataset_id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
indexing_technique: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
retrieval_model: Dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Update dataset configuration.
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
|
||||
name: New dataset name
|
||||
description: New dataset description
|
||||
indexing_technique: Indexing technique ('high_quality' or 'economy')
|
||||
embedding_model: Embedding model name
|
||||
embedding_model_provider: Embedding model provider
|
||||
retrieval_model: Retrieval model configuration dict
|
||||
**kwargs: Additional parameters to pass to the API
|
||||
|
||||
Returns:
|
||||
Response from the API with updated dataset information
|
||||
"""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}"
|
||||
|
||||
# Build data dictionary with all possible parameters
|
||||
payload = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"indexing_technique": indexing_technique,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_model_provider": embedding_model_provider,
|
||||
"retrieval_model": retrieval_model,
|
||||
}
|
||||
|
||||
# Filter out None values and merge with additional kwargs
|
||||
data = {k: v for k, v in payload.items() if v is not None}
|
||||
data.update(kwargs)
|
||||
|
||||
return self._send_request("PATCH", url, json=data)
|
||||
|
||||
def batch_update_document_status(
|
||||
self,
|
||||
action: Literal["enable", "disable", "archive", "un_archive"],
|
||||
document_ids: List[str],
|
||||
dataset_id: str | None = None,
|
||||
):
|
||||
"""Batch update document status (enable/disable/archive/unarchive).
|
||||
|
||||
Args:
|
||||
action: Action to perform on documents
|
||||
- 'enable': Enable documents for retrieval
|
||||
- 'disable': Disable documents from retrieval
|
||||
- 'archive': Archive documents
|
||||
- 'un_archive': Unarchive documents
|
||||
document_ids: List of document IDs to update
|
||||
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
|
||||
|
||||
Returns:
|
||||
Response from the API with operation result
|
||||
"""
|
||||
ds_id = dataset_id or self._get_dataset_id()
|
||||
url = f"/datasets/{ds_id}/documents/status/{action}"
|
||||
data = {"document_ids": document_ids}
|
||||
return self._send_request("PATCH", url, json=data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
[project]
|
||||
name = "dify-client"
|
||||
version = "0.1.12"
|
||||
description = "A package for interacting with the Dify Service-API"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
"aiofiles>=23.0.0",
|
||||
]
|
||||
authors = [
|
||||
{name = "Dify", email = "hello@dify.ai"}
|
||||
]
|
||||
license = {text = "MIT"}
|
||||
keywords = ["dify", "nlp", "ai", "language-processing"]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/langgenius/dify"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["dify_client"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
asyncio_mode = "auto"
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
from setuptools import setup
|
||||
|
||||
with open("README.md", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setup(
|
||||
name="dify-client",
|
||||
version="0.1.12",
|
||||
author="Dify",
|
||||
author_email="hello@dify.ai",
|
||||
description="A package for interacting with the Dify Service-API",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/langgenius/dify",
|
||||
license="MIT",
|
||||
packages=["dify_client"],
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires=">=3.6",
|
||||
install_requires=["requests"],
|
||||
keywords="dify nlp ai language-processing",
|
||||
include_package_data=True,
|
||||
)
|
||||
|
|
@ -0,0 +1,250 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test suite for async client implementation in the Python SDK.
|
||||
|
||||
This test validates the async/await functionality using httpx.AsyncClient
|
||||
and ensures API parity with sync clients.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
|
||||
from dify_client.async_client import (
|
||||
AsyncDifyClient,
|
||||
AsyncChatClient,
|
||||
AsyncCompletionClient,
|
||||
AsyncWorkflowClient,
|
||||
AsyncWorkspaceClient,
|
||||
AsyncKnowledgeBaseClient,
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncAPIParity(unittest.TestCase):
|
||||
"""Test that async clients have API parity with sync clients."""
|
||||
|
||||
def test_dify_client_api_parity(self):
|
||||
"""Test AsyncDifyClient has same methods as DifyClient."""
|
||||
from dify_client import DifyClient
|
||||
|
||||
sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")}
|
||||
|
||||
# aclose is async-specific, close is sync-specific
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
# Verify parity
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient")
|
||||
|
||||
def test_chat_client_api_parity(self):
|
||||
"""Test AsyncChatClient has same methods as ChatClient."""
|
||||
from dify_client import ChatClient
|
||||
|
||||
sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")}
|
||||
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient")
|
||||
|
||||
def test_completion_client_api_parity(self):
|
||||
"""Test AsyncCompletionClient has same methods as CompletionClient."""
|
||||
from dify_client import CompletionClient
|
||||
|
||||
sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")}
|
||||
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient")
|
||||
|
||||
def test_workflow_client_api_parity(self):
|
||||
"""Test AsyncWorkflowClient has same methods as WorkflowClient."""
|
||||
from dify_client import WorkflowClient
|
||||
|
||||
sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")}
|
||||
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient")
|
||||
|
||||
def test_workspace_client_api_parity(self):
|
||||
"""Test AsyncWorkspaceClient has same methods as WorkspaceClient."""
|
||||
from dify_client import WorkspaceClient
|
||||
|
||||
sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")}
|
||||
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient")
|
||||
|
||||
def test_knowledge_base_client_api_parity(self):
|
||||
"""Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient."""
|
||||
from dify_client import KnowledgeBaseClient
|
||||
|
||||
sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")}
|
||||
async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")}
|
||||
|
||||
sync_methods.discard("close")
|
||||
async_methods.discard("aclose")
|
||||
|
||||
self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient")
|
||||
|
||||
|
||||
class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase):
|
||||
"""Test async client with mocked httpx.AsyncClient."""
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_client_initialization(self, mock_httpx_async_client):
|
||||
"""Test async client initializes with httpx.AsyncClient."""
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
client = AsyncDifyClient("test-key", "https://api.dify.ai/v1")
|
||||
|
||||
# Verify httpx.AsyncClient was called
|
||||
mock_httpx_async_client.assert_called_once()
|
||||
self.assertEqual(client.api_key, "test-key")
|
||||
|
||||
await client.aclose()
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_context_manager(self, mock_httpx_async_client):
|
||||
"""Test async context manager works."""
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncDifyClient("test-key") as client:
|
||||
self.assertEqual(client.api_key, "test-key")
|
||||
|
||||
# Verify aclose was called
|
||||
mock_client_instance.aclose.assert_called_once()
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_send_request(self, mock_httpx_async_client):
|
||||
"""Test async _send_request method."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={"result": "success"})
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncDifyClient("test-key") as client:
|
||||
response = await client._send_request("GET", "/test")
|
||||
|
||||
# Verify request was called
|
||||
mock_client_instance.request.assert_called_once()
|
||||
call_args = mock_client_instance.request.call_args
|
||||
|
||||
# Verify parameters
|
||||
self.assertEqual(call_args[0][0], "GET")
|
||||
self.assertEqual(call_args[0][1], "/test")
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_chat_client(self, mock_httpx_async_client):
|
||||
"""Test AsyncChatClient functionality."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.text = '{"answer": "Hello!"}'
|
||||
mock_response.json = AsyncMock(return_value={"answer": "Hello!"})
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncChatClient("test-key") as client:
|
||||
response = await client.create_chat_message({}, "Hi", "user123")
|
||||
self.assertIn("answer", response.text)
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_completion_client(self, mock_httpx_async_client):
|
||||
"""Test AsyncCompletionClient functionality."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.text = '{"answer": "Response"}'
|
||||
mock_response.json = AsyncMock(return_value={"answer": "Response"})
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncCompletionClient("test-key") as client:
|
||||
response = await client.create_completion_message({"query": "test"}, "blocking", "user123")
|
||||
self.assertIn("answer", response.text)
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_workflow_client(self, mock_httpx_async_client):
|
||||
"""Test AsyncWorkflowClient functionality."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={"result": "success"})
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncWorkflowClient("test-key") as client:
|
||||
response = await client.run({"input": "test"}, "blocking", "user123")
|
||||
data = await response.json()
|
||||
self.assertEqual(data["result"], "success")
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_workspace_client(self, mock_httpx_async_client):
|
||||
"""Test AsyncWorkspaceClient functionality."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={"data": []})
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncWorkspaceClient("test-key") as client:
|
||||
response = await client.get_available_models("llm")
|
||||
data = await response.json()
|
||||
self.assertIn("data", data)
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_async_knowledge_base_client(self, mock_httpx_async_client):
|
||||
"""Test AsyncKnowledgeBaseClient functionality."""
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json = AsyncMock(return_value={"data": [], "total": 0})
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
async with AsyncKnowledgeBaseClient("test-key") as client:
|
||||
response = await client.list_datasets()
|
||||
data = await response.json()
|
||||
self.assertIn("data", data)
|
||||
|
||||
@patch("dify_client.async_client.httpx.AsyncClient")
|
||||
async def test_all_async_client_classes(self, mock_httpx_async_client):
|
||||
"""Test all async client classes work with httpx.AsyncClient."""
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_httpx_async_client.return_value = mock_client_instance
|
||||
|
||||
clients = [
|
||||
AsyncDifyClient("key"),
|
||||
AsyncChatClient("key"),
|
||||
AsyncCompletionClient("key"),
|
||||
AsyncWorkflowClient("key"),
|
||||
AsyncWorkspaceClient("key"),
|
||||
AsyncKnowledgeBaseClient("key"),
|
||||
]
|
||||
|
||||
# Verify httpx.AsyncClient was called for each
|
||||
self.assertEqual(mock_httpx_async_client.call_count, 6)
|
||||
|
||||
# Clean up
|
||||
for client in clients:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,331 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test suite for httpx migration in the Python SDK.
|
||||
|
||||
This test validates that the migration from requests to httpx maintains
|
||||
backward compatibility and proper resource management.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from dify_client import (
|
||||
DifyClient,
|
||||
ChatClient,
|
||||
CompletionClient,
|
||||
WorkflowClient,
|
||||
WorkspaceClient,
|
||||
KnowledgeBaseClient,
|
||||
)
|
||||
|
||||
|
||||
class TestHttpxMigrationMocked(unittest.TestCase):
|
||||
"""Test cases for httpx migration with mocked requests."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.api_key = "test-api-key"
|
||||
self.base_url = "https://api.dify.ai/v1"
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_client_initialization(self, mock_httpx_client):
|
||||
"""Test that client initializes with httpx.Client."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
|
||||
# Verify httpx.Client was called with correct parameters
|
||||
mock_httpx_client.assert_called_once()
|
||||
call_kwargs = mock_httpx_client.call_args[1]
|
||||
self.assertEqual(call_kwargs["base_url"], self.base_url)
|
||||
|
||||
# Verify client properties
|
||||
self.assertEqual(client.api_key, self.api_key)
|
||||
self.assertEqual(client.base_url, self.base_url)
|
||||
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_context_manager_support(self, mock_httpx_client):
|
||||
"""Test that client works as context manager."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with DifyClient(self.api_key, self.base_url) as client:
|
||||
self.assertEqual(client.api_key, self.api_key)
|
||||
|
||||
# Verify close was called
|
||||
mock_client_instance.close.assert_called_once()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_manual_close(self, mock_httpx_client):
|
||||
"""Test manual close() method."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
client.close()
|
||||
|
||||
# Verify close was called
|
||||
mock_client_instance.close.assert_called_once()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_send_request_httpx_compatibility(self, mock_httpx_client):
|
||||
"""Test _send_request uses httpx.Client.request properly."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
response = client._send_request("GET", "/test-endpoint")
|
||||
|
||||
# Verify httpx.Client.request was called correctly
|
||||
mock_client_instance.request.assert_called_once()
|
||||
call_args = mock_client_instance.request.call_args
|
||||
|
||||
# Verify method and endpoint
|
||||
self.assertEqual(call_args[0][0], "GET")
|
||||
self.assertEqual(call_args[0][1], "/test-endpoint")
|
||||
|
||||
# Verify headers contain authorization
|
||||
headers = call_args[1]["headers"]
|
||||
self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}")
|
||||
self.assertEqual(headers["Content-Type"], "application/json")
|
||||
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_response_compatibility(self, mock_httpx_client):
|
||||
"""Test httpx.Response is compatible with requests.Response API."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"key": "value"}
|
||||
mock_response.text = '{"key": "value"}'
|
||||
mock_response.content = b'{"key": "value"}'
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
response = client._send_request("GET", "/test")
|
||||
|
||||
# Verify all common response methods work
|
||||
self.assertEqual(response.json(), {"key": "value"})
|
||||
self.assertEqual(response.text, '{"key": "value"}')
|
||||
self.assertEqual(response.content, b'{"key": "value"}')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.headers["Content-Type"], "application/json")
|
||||
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_all_client_classes_use_httpx(self, mock_httpx_client):
|
||||
"""Test that all client classes properly use httpx."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
clients = [
|
||||
DifyClient(self.api_key, self.base_url),
|
||||
ChatClient(self.api_key, self.base_url),
|
||||
CompletionClient(self.api_key, self.base_url),
|
||||
WorkflowClient(self.api_key, self.base_url),
|
||||
WorkspaceClient(self.api_key, self.base_url),
|
||||
KnowledgeBaseClient(self.api_key, self.base_url),
|
||||
]
|
||||
|
||||
# Verify httpx.Client was called for each client
|
||||
self.assertEqual(mock_httpx_client.call_count, 6)
|
||||
|
||||
# Clean up
|
||||
for client in clients:
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_json_parameter_handling(self, mock_httpx_client):
|
||||
"""Test that json parameter is passed correctly."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
test_data = {"key": "value", "number": 123}
|
||||
|
||||
client._send_request("POST", "/test", json=test_data)
|
||||
|
||||
# Verify json parameter was passed
|
||||
call_args = mock_client_instance.request.call_args
|
||||
self.assertEqual(call_args[1]["json"], test_data)
|
||||
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_params_parameter_handling(self, mock_httpx_client):
|
||||
"""Test that params parameter is passed correctly."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
test_params = {"page": 1, "limit": 20}
|
||||
|
||||
client._send_request("GET", "/test", params=test_params)
|
||||
|
||||
# Verify params parameter was passed
|
||||
call_args = mock_client_instance.request.call_args
|
||||
self.assertEqual(call_args[1]["params"], test_params)
|
||||
|
||||
client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_inheritance_chain(self, mock_httpx_client):
|
||||
"""Test that inheritance chain is maintained."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
# ChatClient inherits from DifyClient
|
||||
chat_client = ChatClient(self.api_key, self.base_url)
|
||||
self.assertIsInstance(chat_client, DifyClient)
|
||||
|
||||
# CompletionClient inherits from DifyClient
|
||||
completion_client = CompletionClient(self.api_key, self.base_url)
|
||||
self.assertIsInstance(completion_client, DifyClient)
|
||||
|
||||
# WorkflowClient inherits from DifyClient
|
||||
workflow_client = WorkflowClient(self.api_key, self.base_url)
|
||||
self.assertIsInstance(workflow_client, DifyClient)
|
||||
|
||||
# Clean up
|
||||
chat_client.close()
|
||||
completion_client.close()
|
||||
workflow_client.close()
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_nested_context_managers(self, mock_httpx_client):
|
||||
"""Test nested context managers work correctly."""
|
||||
mock_client_instance = Mock()
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with DifyClient(self.api_key, self.base_url) as client1:
|
||||
with ChatClient(self.api_key, self.base_url) as client2:
|
||||
self.assertEqual(client1.api_key, self.api_key)
|
||||
self.assertEqual(client2.api_key, self.api_key)
|
||||
|
||||
# Both close methods should have been called
|
||||
self.assertEqual(mock_client_instance.close.call_count, 2)
|
||||
|
||||
|
||||
class TestChatClientHttpx(unittest.TestCase):
|
||||
"""Test ChatClient specific httpx integration."""
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_create_chat_message_httpx(self, mock_httpx_client):
|
||||
"""Test create_chat_message works with httpx."""
|
||||
mock_response = Mock()
|
||||
mock_response.text = '{"answer": "Hello!"}'
|
||||
mock_response.json.return_value = {"answer": "Hello!"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with ChatClient("test-key") as client:
|
||||
response = client.create_chat_message({}, "Hi", "user123")
|
||||
self.assertIn("answer", response.text)
|
||||
self.assertEqual(response.json()["answer"], "Hello!")
|
||||
|
||||
|
||||
class TestCompletionClientHttpx(unittest.TestCase):
|
||||
"""Test CompletionClient specific httpx integration."""
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_create_completion_message_httpx(self, mock_httpx_client):
|
||||
"""Test create_completion_message works with httpx."""
|
||||
mock_response = Mock()
|
||||
mock_response.text = '{"answer": "Response"}'
|
||||
mock_response.json.return_value = {"answer": "Response"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with CompletionClient("test-key") as client:
|
||||
response = client.create_completion_message({"query": "test"}, "blocking", "user123")
|
||||
self.assertIn("answer", response.text)
|
||||
|
||||
|
||||
class TestKnowledgeBaseClientHttpx(unittest.TestCase):
|
||||
"""Test KnowledgeBaseClient specific httpx integration."""
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_list_datasets_httpx(self, mock_httpx_client):
|
||||
"""Test list_datasets works with httpx."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"data": [], "total": 0}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with KnowledgeBaseClient("test-key") as client:
|
||||
response = client.list_datasets()
|
||||
data = response.json()
|
||||
self.assertIn("data", data)
|
||||
self.assertIn("total", data)
|
||||
|
||||
|
||||
class TestWorkflowClientHttpx(unittest.TestCase):
|
||||
"""Test WorkflowClient specific httpx integration."""
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_run_workflow_httpx(self, mock_httpx_client):
|
||||
"""Test run workflow works with httpx."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with WorkflowClient("test-key") as client:
|
||||
response = client.run({"input": "test"}, "blocking", "user123")
|
||||
self.assertEqual(response.json()["result"], "success")
|
||||
|
||||
|
||||
class TestWorkspaceClientHttpx(unittest.TestCase):
|
||||
"""Test WorkspaceClient specific httpx integration."""
|
||||
|
||||
@patch("dify_client.client.httpx.Client")
|
||||
def test_get_available_models_httpx(self, mock_httpx_client):
|
||||
"""Test get_available_models works with httpx."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.status_code = 200
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_httpx_client.return_value = mock_client_instance
|
||||
|
||||
with WorkspaceClient("test-key") as client:
|
||||
response = client.get_available_models("llm")
|
||||
self.assertIn("data", response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,416 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test suite for the new Service API functionality in the Python SDK.
|
||||
|
||||
This test validates the implementation of the missing Service API endpoints
|
||||
that were added to the Python SDK to achieve complete coverage.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import json
|
||||
|
||||
from dify_client import (
|
||||
DifyClient,
|
||||
ChatClient,
|
||||
WorkflowClient,
|
||||
KnowledgeBaseClient,
|
||||
WorkspaceClient,
|
||||
)
|
||||
|
||||
|
||||
class TestNewServiceAPIs(unittest.TestCase):
|
||||
"""Test cases for new Service API implementations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.api_key = "test-api-key"
|
||||
self.base_url = "https://api.dify.ai/v1"
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_app_info_apis(self, mock_request):
|
||||
"""Test application info APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"name": "Test App",
|
||||
"description": "Test Description",
|
||||
"tags": ["test", "api"],
|
||||
"mode": "chat",
|
||||
"author_name": "Test Author",
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
client = DifyClient(self.api_key, self.base_url)
|
||||
|
||||
# Test get_app_info
|
||||
result = client.get_app_info()
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/info",
|
||||
json=None,
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test get_app_site_info
|
||||
client.get_app_site_info()
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/site",
|
||||
json=None,
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test get_file_preview
|
||||
file_id = "test-file-id"
|
||||
client.get_file_preview(file_id)
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/files/{file_id}/preview",
|
||||
json=None,
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_annotation_apis(self, mock_request):
|
||||
"""Test annotation APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
client = ChatClient(self.api_key, self.base_url)
|
||||
|
||||
# Test annotation_reply_action - enable
|
||||
client.annotation_reply_action(
|
||||
action="enable",
|
||||
score_threshold=0.8,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-ada-002",
|
||||
)
|
||||
mock_request.assert_called_with(
|
||||
"POST",
|
||||
f"{self.base_url}/apps/annotation-reply/enable",
|
||||
json={
|
||||
"score_threshold": 0.8,
|
||||
"embedding_provider_name": "openai",
|
||||
"embedding_model_name": "text-embedding-ada-002",
|
||||
},
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test annotation_reply_action - disable (now requires same fields as enable)
|
||||
client.annotation_reply_action(
|
||||
action="disable",
|
||||
score_threshold=0.5,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# Test annotation_reply_action with score_threshold=0 (edge case)
|
||||
client.annotation_reply_action(
|
||||
action="enable",
|
||||
score_threshold=0.0, # This should work and not raise ValueError
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# Test get_annotation_reply_status
|
||||
client.get_annotation_reply_status("enable", "job-123")
|
||||
|
||||
# Test list_annotations
|
||||
client.list_annotations(page=1, limit=20, keyword="test")
|
||||
|
||||
# Test create_annotation
|
||||
client.create_annotation("Test question?", "Test answer.")
|
||||
|
||||
# Test update_annotation
|
||||
client.update_annotation("annotation-123", "Updated question?", "Updated answer.")
|
||||
|
||||
# Test delete_annotation
|
||||
client.delete_annotation("annotation-123")
|
||||
|
||||
# Verify all calls were made (8 calls: enable + disable + enable with 0.0 + 5 other operations)
|
||||
self.assertEqual(mock_request.call_count, 8)
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_knowledge_base_advanced_apis(self, mock_request):
|
||||
"""Test advanced knowledge base APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
dataset_id = "test-dataset-id"
|
||||
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
|
||||
|
||||
# Test hit_testing
|
||||
client.hit_testing("test query", {"type": "vector"})
|
||||
mock_request.assert_called_with(
|
||||
"POST",
|
||||
f"{self.base_url}/datasets/{dataset_id}/hit-testing",
|
||||
json={"query": "test query", "retrieval_model": {"type": "vector"}},
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test metadata operations
|
||||
client.get_dataset_metadata()
|
||||
client.create_dataset_metadata({"key": "value"})
|
||||
client.update_dataset_metadata("meta-123", {"key": "new_value"})
|
||||
client.get_built_in_metadata()
|
||||
client.manage_built_in_metadata("enable", {"type": "built_in"})
|
||||
client.update_documents_metadata([{"document_id": "doc1", "metadata": {"key": "value"}}])
|
||||
|
||||
# Test tag operations
|
||||
client.list_dataset_tags()
|
||||
client.bind_dataset_tags(["tag1", "tag2"])
|
||||
client.unbind_dataset_tag("tag1")
|
||||
client.get_dataset_tags()
|
||||
|
||||
# Verify multiple calls were made
|
||||
self.assertGreater(mock_request.call_count, 5)
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_rag_pipeline_apis(self, mock_request):
|
||||
"""Test RAG pipeline APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
dataset_id = "test-dataset-id"
|
||||
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
|
||||
|
||||
# Test get_datasource_plugins
|
||||
client.get_datasource_plugins(is_published=True)
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/datasets/{dataset_id}/pipeline/datasource-plugins",
|
||||
json=None,
|
||||
params={"is_published": True},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test run_datasource_node
|
||||
client.run_datasource_node(
|
||||
node_id="node-123",
|
||||
inputs={"param": "value"},
|
||||
datasource_type="online_document",
|
||||
is_published=True,
|
||||
credential_id="cred-123",
|
||||
)
|
||||
|
||||
# Test run_rag_pipeline with blocking mode
|
||||
client.run_rag_pipeline(
|
||||
inputs={"query": "test"},
|
||||
datasource_type="online_document",
|
||||
datasource_info_list=[{"id": "ds1"}],
|
||||
start_node_id="start-node",
|
||||
is_published=True,
|
||||
response_mode="blocking",
|
||||
)
|
||||
|
||||
# Test run_rag_pipeline with streaming mode
|
||||
client.run_rag_pipeline(
|
||||
inputs={"query": "test"},
|
||||
datasource_type="online_document",
|
||||
datasource_info_list=[{"id": "ds1"}],
|
||||
start_node_id="start-node",
|
||||
is_published=True,
|
||||
response_mode="streaming",
|
||||
)
|
||||
|
||||
self.assertEqual(mock_request.call_count, 4)
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_workspace_apis(self, mock_request):
|
||||
"""Test workspace APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"name": "gpt-3.5-turbo", "type": "llm"}, {"name": "gpt-4", "type": "llm"}]
|
||||
}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
client = WorkspaceClient(self.api_key, self.base_url)
|
||||
|
||||
# Test get_available_models
|
||||
result = client.get_available_models("llm")
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/workspaces/current/models/model-types/llm",
|
||||
json=None,
|
||||
params=None,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_workflow_advanced_apis(self, mock_request):
|
||||
"""Test advanced workflow APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
client = WorkflowClient(self.api_key, self.base_url)
|
||||
|
||||
# Test get_workflow_logs
|
||||
client.get_workflow_logs(keyword="test", status="succeeded", page=1, limit=20)
|
||||
mock_request.assert_called_with(
|
||||
"GET",
|
||||
f"{self.base_url}/workflows/logs",
|
||||
json=None,
|
||||
params={"page": 1, "limit": 20, "keyword": "test", "status": "succeeded"},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
# Test get_workflow_logs with additional filters
|
||||
client.get_workflow_logs(
|
||||
keyword="test",
|
||||
status="succeeded",
|
||||
page=1,
|
||||
limit=20,
|
||||
created_at__before="2024-01-01",
|
||||
created_at__after="2023-01-01",
|
||||
created_by_account="user123",
|
||||
)
|
||||
|
||||
# Test run_specific_workflow
|
||||
client.run_specific_workflow(
|
||||
workflow_id="workflow-123", inputs={"param": "value"}, response_mode="streaming", user="user-123"
|
||||
)
|
||||
|
||||
self.assertEqual(mock_request.call_count, 3)
|
||||
|
||||
def test_error_handling(self):
|
||||
"""Test error handling for required parameters."""
|
||||
client = ChatClient(self.api_key, self.base_url)
|
||||
|
||||
# Test annotation_reply_action with missing required parameters would be a TypeError now
|
||||
# since parameters are required in method signature
|
||||
with self.assertRaises(TypeError):
|
||||
client.annotation_reply_action("enable")
|
||||
|
||||
# Test annotation_reply_action with explicit None values should raise ValueError
|
||||
with self.assertRaises(ValueError) as context:
|
||||
client.annotation_reply_action("enable", None, "provider", "model")
|
||||
|
||||
self.assertIn("cannot be None", str(context.exception))
|
||||
|
||||
# Test KnowledgeBaseClient without dataset_id
|
||||
kb_client = KnowledgeBaseClient(self.api_key, self.base_url)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
kb_client.hit_testing("test query")
|
||||
|
||||
self.assertIn("dataset_id is not set", str(context.exception))
|
||||
|
||||
@patch("dify_client.client.open")
|
||||
@patch("dify_client.client.requests.request")
|
||||
def test_file_upload_apis(self, mock_request, mock_open):
|
||||
"""Test file upload APIs."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_open.return_value.__enter__.return_value = mock_file
|
||||
|
||||
dataset_id = "test-dataset-id"
|
||||
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
|
||||
|
||||
# Test upload_pipeline_file
|
||||
client.upload_pipeline_file("/path/to/test.pdf")
|
||||
|
||||
mock_open.assert_called_with("/path/to/test.pdf", "rb")
|
||||
mock_request.assert_called_once()
|
||||
|
||||
def test_comprehensive_coverage(self):
|
||||
"""Test that all previously missing APIs are now implemented."""
|
||||
|
||||
# Test DifyClient methods
|
||||
dify_methods = ["get_app_info", "get_app_site_info", "get_file_preview"]
|
||||
client = DifyClient(self.api_key)
|
||||
for method in dify_methods:
|
||||
self.assertTrue(hasattr(client, method), f"DifyClient missing method: {method}")
|
||||
|
||||
# Test ChatClient annotation methods
|
||||
chat_methods = [
|
||||
"annotation_reply_action",
|
||||
"get_annotation_reply_status",
|
||||
"list_annotations",
|
||||
"create_annotation",
|
||||
"update_annotation",
|
||||
"delete_annotation",
|
||||
]
|
||||
chat_client = ChatClient(self.api_key)
|
||||
for method in chat_methods:
|
||||
self.assertTrue(hasattr(chat_client, method), f"ChatClient missing method: {method}")
|
||||
|
||||
# Test WorkflowClient advanced methods
|
||||
workflow_methods = ["get_workflow_logs", "run_specific_workflow"]
|
||||
workflow_client = WorkflowClient(self.api_key)
|
||||
for method in workflow_methods:
|
||||
self.assertTrue(hasattr(workflow_client, method), f"WorkflowClient missing method: {method}")
|
||||
|
||||
# Test KnowledgeBaseClient advanced methods
|
||||
kb_methods = [
|
||||
"hit_testing",
|
||||
"get_dataset_metadata",
|
||||
"create_dataset_metadata",
|
||||
"update_dataset_metadata",
|
||||
"get_built_in_metadata",
|
||||
"manage_built_in_metadata",
|
||||
"update_documents_metadata",
|
||||
"list_dataset_tags",
|
||||
"bind_dataset_tags",
|
||||
"unbind_dataset_tag",
|
||||
"get_dataset_tags",
|
||||
"get_datasource_plugins",
|
||||
"run_datasource_node",
|
||||
"run_rag_pipeline",
|
||||
"upload_pipeline_file",
|
||||
]
|
||||
kb_client = KnowledgeBaseClient(self.api_key)
|
||||
for method in kb_methods:
|
||||
self.assertTrue(hasattr(kb_client, method), f"KnowledgeBaseClient missing method: {method}")
|
||||
|
||||
# Test WorkspaceClient methods
|
||||
workspace_methods = ["get_available_models"]
|
||||
workspace_client = WorkspaceClient(self.api_key)
|
||||
for method in workspace_methods:
|
||||
self.assertTrue(hasattr(workspace_client, method), f"WorkspaceClient missing method: {method}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,271 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
requires-python = ">=3.10"
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
version = "25.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.11.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "idna" },
|
||||
{ name = "sniffio" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backports-asyncio-runner"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.10.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dify-client"
|
||||
version = "0.1.12"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "httpx" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h11"
|
||||
version = "0.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "certifi" },
|
||||
{ name = "httpcore" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" },
|
||||
{ name = "pytest" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
|
||||
]
|
||||
|
|
@ -13,39 +13,60 @@ import { ThemeProvider } from 'next-themes'
|
|||
import useTheme from '@/hooks/use-theme'
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
const DARK_MODE_MEDIA_QUERY = /prefers-color-scheme:\s*dark/i
|
||||
|
||||
// Setup browser environment for testing
|
||||
const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = false) => {
|
||||
// Mock localStorage
|
||||
const mockStorage = {
|
||||
getItem: jest.fn((key: string) => {
|
||||
if (key === 'theme') return storedTheme
|
||||
return null
|
||||
}),
|
||||
setItem: jest.fn(),
|
||||
removeItem: jest.fn(),
|
||||
if (typeof window === 'undefined')
|
||||
return
|
||||
|
||||
try {
|
||||
window.localStorage.clear()
|
||||
}
|
||||
catch {
|
||||
// ignore if localStorage has been replaced by a throwing stub
|
||||
}
|
||||
|
||||
// Mock system theme preference
|
||||
const mockMatchMedia = jest.fn((query: string) => ({
|
||||
matches: query.includes('dark') && systemPrefersDark,
|
||||
media: query,
|
||||
addListener: jest.fn(),
|
||||
removeListener: jest.fn(),
|
||||
}))
|
||||
if (storedTheme === null)
|
||||
window.localStorage.removeItem('theme')
|
||||
else
|
||||
window.localStorage.setItem('theme', storedTheme)
|
||||
|
||||
if (typeof window !== 'undefined') {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: mockStorage,
|
||||
configurable: true,
|
||||
})
|
||||
document.documentElement.removeAttribute('data-theme')
|
||||
|
||||
Object.defineProperty(window, 'matchMedia', {
|
||||
value: mockMatchMedia,
|
||||
configurable: true,
|
||||
})
|
||||
const mockMatchMedia: typeof window.matchMedia = (query: string) => {
|
||||
const listeners = new Set<(event: MediaQueryListEvent) => void>()
|
||||
const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query)
|
||||
const matches = isDarkQuery ? systemPrefersDark : false
|
||||
|
||||
const mediaQueryList: MediaQueryList = {
|
||||
matches,
|
||||
media: query,
|
||||
onchange: null,
|
||||
addListener: (listener: MediaQueryListListener) => {
|
||||
listeners.add(listener)
|
||||
},
|
||||
removeListener: (listener: MediaQueryListListener) => {
|
||||
listeners.delete(listener)
|
||||
},
|
||||
addEventListener: (_event, listener: EventListener) => {
|
||||
if (typeof listener === 'function')
|
||||
listeners.add(listener as MediaQueryListListener)
|
||||
},
|
||||
removeEventListener: (_event, listener: EventListener) => {
|
||||
if (typeof listener === 'function')
|
||||
listeners.delete(listener as MediaQueryListListener)
|
||||
},
|
||||
dispatchEvent: (event: Event) => {
|
||||
listeners.forEach(listener => listener(event as MediaQueryListEvent))
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
return mediaQueryList
|
||||
}
|
||||
|
||||
return { mockStorage, mockMatchMedia }
|
||||
jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
|
||||
}
|
||||
|
||||
// Simulate real page component based on Dify's actual theme usage
|
||||
|
|
@ -94,7 +115,17 @@ const TestThemeProvider = ({ children }: { children: React.ReactNode }) => (
|
|||
|
||||
describe('Real Browser Environment Dark Mode Flicker Test', () => {
|
||||
beforeEach(() => {
|
||||
jest.restoreAllMocks()
|
||||
jest.clearAllMocks()
|
||||
if (typeof window !== 'undefined') {
|
||||
try {
|
||||
window.localStorage.clear()
|
||||
}
|
||||
catch {
|
||||
// ignore when localStorage is replaced with an error-throwing stub
|
||||
}
|
||||
document.documentElement.removeAttribute('data-theme')
|
||||
}
|
||||
})
|
||||
|
||||
describe('Page Refresh Scenario Simulation', () => {
|
||||
|
|
@ -323,35 +354,40 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
|
|||
|
||||
describe('Edge Cases and Error Handling', () => {
|
||||
test('handles localStorage access errors gracefully', async () => {
|
||||
// Mock localStorage to throw an error
|
||||
setupMockEnvironment(null)
|
||||
|
||||
const mockStorage = {
|
||||
getItem: jest.fn(() => {
|
||||
throw new Error('LocalStorage access denied')
|
||||
}),
|
||||
setItem: jest.fn(),
|
||||
removeItem: jest.fn(),
|
||||
clear: jest.fn(),
|
||||
}
|
||||
|
||||
if (typeof window !== 'undefined') {
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: mockStorage,
|
||||
configurable: true,
|
||||
})
|
||||
}
|
||||
|
||||
render(
|
||||
<TestThemeProvider>
|
||||
<PageComponent />
|
||||
</TestThemeProvider>,
|
||||
)
|
||||
|
||||
// Should fallback gracefully without crashing
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('theme-indicator')).toBeInTheDocument()
|
||||
Object.defineProperty(window, 'localStorage', {
|
||||
value: mockStorage,
|
||||
configurable: true,
|
||||
})
|
||||
|
||||
// Should default to light theme when localStorage fails
|
||||
expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light')
|
||||
try {
|
||||
render(
|
||||
<TestThemeProvider>
|
||||
<PageComponent />
|
||||
</TestThemeProvider>,
|
||||
)
|
||||
|
||||
// Should fallback gracefully without crashing
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('theme-indicator')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Should default to light theme when localStorage fails
|
||||
expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light')
|
||||
}
|
||||
finally {
|
||||
Reflect.deleteProperty(window, 'localStorage')
|
||||
}
|
||||
})
|
||||
|
||||
test('handles invalid theme values in localStorage', async () => {
|
||||
|
|
@ -403,6 +439,8 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
|
|||
|
||||
setupMockEnvironment('dark')
|
||||
|
||||
expect(window.localStorage.getItem('theme')).toBe('dark')
|
||||
|
||||
render(
|
||||
<TestThemeProvider>
|
||||
<PerformanceTestComponent />
|
||||
|
|
|
|||
|
|
@ -17,12 +17,9 @@ import type {
|
|||
import { noop } from 'lodash-es'
|
||||
|
||||
export type EmbeddedChatbotContextValue = {
|
||||
userCanAccess?: boolean
|
||||
appInfoError?: any
|
||||
appInfoLoading?: boolean
|
||||
appMeta?: AppMeta
|
||||
appData?: AppData
|
||||
appParams?: ChatConfig
|
||||
appMeta: AppMeta | null
|
||||
appData: AppData | null
|
||||
appParams: ChatConfig | null
|
||||
appChatListDataLoading?: boolean
|
||||
currentConversationId: string
|
||||
currentConversationItem?: ConversationItem
|
||||
|
|
@ -59,7 +56,10 @@ export type EmbeddedChatbotContextValue = {
|
|||
}
|
||||
|
||||
export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({
|
||||
userCanAccess: false,
|
||||
appData: null,
|
||||
appMeta: null,
|
||||
appParams: null,
|
||||
appChatListDataLoading: false,
|
||||
currentConversationId: '',
|
||||
appPrevChatList: [],
|
||||
pinnedConversationList: [],
|
||||
|
|
|
|||
|
|
@ -18,9 +18,6 @@ import { CONVERSATION_ID_INFO } from '../constants'
|
|||
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, getProcessedUserVariablesFromUrlParams } from '../utils'
|
||||
import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
|
||||
import {
|
||||
fetchAppInfo,
|
||||
fetchAppMeta,
|
||||
fetchAppParams,
|
||||
fetchChatList,
|
||||
fetchConversations,
|
||||
generationConversationName,
|
||||
|
|
@ -36,8 +33,7 @@ import { InputVarType } from '@/app/components/workflow/types'
|
|||
import { TransferMethod } from '@/types/app'
|
||||
import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
|
||||
function getFormattedChatList(messages: any[]) {
|
||||
const newChatList: ChatItem[] = []
|
||||
|
|
@ -67,18 +63,10 @@ function getFormattedChatList(messages: any[]) {
|
|||
|
||||
export const useEmbeddedChatbot = () => {
|
||||
const isInstalledApp = false
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo)
|
||||
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
|
||||
appId: appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
enabled: systemFeatures.webapp_auth.enabled,
|
||||
})
|
||||
|
||||
const appData = useMemo(() => {
|
||||
return appInfo
|
||||
}, [appInfo])
|
||||
const appId = useMemo(() => appData?.app_id, [appData])
|
||||
const appInfo = useWebAppStore(s => s.appInfo)
|
||||
const appMeta = useWebAppStore(s => s.appMeta)
|
||||
const appParams = useWebAppStore(s => s.appParams)
|
||||
const appId = useMemo(() => appInfo?.app_id, [appInfo])
|
||||
|
||||
const [userId, setUserId] = useState<string>()
|
||||
const [conversationId, setConversationId] = useState<string>()
|
||||
|
|
@ -145,8 +133,6 @@ export const useEmbeddedChatbot = () => {
|
|||
return currentConversationId
|
||||
}, [currentConversationId, newConversationId])
|
||||
|
||||
const { data: appParams } = useSWR(['appParams', isInstalledApp, appId], () => fetchAppParams(isInstalledApp, appId))
|
||||
const { data: appMeta } = useSWR(['appMeta', isInstalledApp, appId], () => fetchAppMeta(isInstalledApp, appId))
|
||||
const { data: appPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100))
|
||||
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100))
|
||||
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId))
|
||||
|
|
@ -398,16 +384,13 @@ export const useEmbeddedChatbot = () => {
|
|||
}, [isInstalledApp, appId, t, notify])
|
||||
|
||||
return {
|
||||
appInfoError,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
|
||||
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
|
||||
isInstalledApp,
|
||||
allowResetChat,
|
||||
appId,
|
||||
currentConversationId,
|
||||
currentConversationItem,
|
||||
handleConversationIdInfoChange,
|
||||
appData,
|
||||
appData: appInfo,
|
||||
appParams: appParams || {} as ChatConfig,
|
||||
appMeta,
|
||||
appPinnedConversationData,
|
||||
|
|
|
|||
|
|
@ -101,7 +101,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
|
||||
const {
|
||||
appData,
|
||||
userCanAccess,
|
||||
appParams,
|
||||
appMeta,
|
||||
appChatListDataLoading,
|
||||
|
|
@ -135,7 +134,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
} = useEmbeddedChatbot()
|
||||
|
||||
return <EmbeddedChatbotContext.Provider value={{
|
||||
userCanAccess,
|
||||
appData,
|
||||
appParams,
|
||||
appMeta,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
import React from 'react'
|
||||
import { cleanup, fireEvent, render } from '@testing-library/react'
|
||||
import InlineDeleteConfirm from './index'
|
||||
|
||||
// Mock react-i18next
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
'common.operation.deleteConfirmTitle': 'Delete?',
|
||||
'common.operation.yes': 'Yes',
|
||||
'common.operation.no': 'No',
|
||||
'common.operation.confirmAction': 'Please confirm your action.',
|
||||
}
|
||||
return translations[key] || key
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
afterEach(cleanup)
|
||||
|
||||
describe('InlineDeleteConfirm', () => {
|
||||
describe('Rendering', () => {
|
||||
test('should render with default text', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
|
||||
)
|
||||
|
||||
expect(getByText('Delete?')).toBeInTheDocument()
|
||||
expect(getByText('No')).toBeInTheDocument()
|
||||
expect(getByText('Yes')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
test('should render with custom text', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm
|
||||
title="Remove?"
|
||||
confirmText="Confirm"
|
||||
cancelText="Cancel"
|
||||
onConfirm={onConfirm}
|
||||
onCancel={onCancel}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(getByText('Remove?')).toBeInTheDocument()
|
||||
expect(getByText('Cancel')).toBeInTheDocument()
|
||||
expect(getByText('Confirm')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
test('should have proper ARIA attributes', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { container } = render(
|
||||
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
|
||||
)
|
||||
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveAttribute('aria-labelledby', 'inline-delete-confirm-title')
|
||||
expect(wrapper).toHaveAttribute('aria-describedby', 'inline-delete-confirm-description')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Button interactions', () => {
|
||||
test('should call onCancel when cancel button is clicked', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
|
||||
)
|
||||
|
||||
fireEvent.click(getByText('No'))
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
expect(onConfirm).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test('should call onConfirm when confirm button is clicked', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
|
||||
)
|
||||
|
||||
fireEvent.click(getByText('Yes'))
|
||||
expect(onConfirm).toHaveBeenCalledTimes(1)
|
||||
expect(onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Variant prop', () => {
|
||||
test('should render with delete variant by default', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
|
||||
)
|
||||
|
||||
const confirmButton = getByText('Yes').closest('button')
|
||||
expect(confirmButton?.className).toContain('btn-destructive')
|
||||
})
|
||||
|
||||
test('should render without destructive class for warning variant', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm
|
||||
variant="warning"
|
||||
onConfirm={onConfirm}
|
||||
onCancel={onCancel}
|
||||
/>,
|
||||
)
|
||||
|
||||
const confirmButton = getByText('Yes').closest('button')
|
||||
expect(confirmButton?.className).not.toContain('btn-destructive')
|
||||
})
|
||||
|
||||
test('should render without destructive class for info variant', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { getByText } = render(
|
||||
<InlineDeleteConfirm
|
||||
variant="info"
|
||||
onConfirm={onConfirm}
|
||||
onCancel={onCancel}
|
||||
/>,
|
||||
)
|
||||
|
||||
const confirmButton = getByText('Yes').closest('button')
|
||||
expect(confirmButton?.className).not.toContain('btn-destructive')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Custom className', () => {
|
||||
test('should apply custom className to wrapper', () => {
|
||||
const onConfirm = jest.fn()
|
||||
const onCancel = jest.fn()
|
||||
const { container } = render(
|
||||
<InlineDeleteConfirm
|
||||
className="custom-class"
|
||||
onConfirm={onConfirm}
|
||||
onCancel={onCancel}
|
||||
/>,
|
||||
)
|
||||
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper.className).toContain('custom-class')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
export type InlineDeleteConfirmProps = {
|
||||
title?: string
|
||||
confirmText?: string
|
||||
cancelText?: string
|
||||
onConfirm: () => void
|
||||
onCancel: () => void
|
||||
className?: string
|
||||
variant?: 'delete' | 'warning' | 'info'
|
||||
}
|
||||
|
||||
const InlineDeleteConfirm: FC<InlineDeleteConfirmProps> = ({
|
||||
title,
|
||||
confirmText,
|
||||
cancelText,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
className,
|
||||
variant = 'delete',
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const titleText = title || t('common.operation.deleteConfirmTitle', 'Delete?')
|
||||
const confirmTxt = confirmText || t('common.operation.yes', 'Yes')
|
||||
const cancelTxt = cancelText || t('common.operation.no', 'No')
|
||||
|
||||
return (
|
||||
<div
|
||||
aria-labelledby="inline-delete-confirm-title"
|
||||
aria-describedby="inline-delete-confirm-description"
|
||||
className={cn(
|
||||
'flex w-[120px] flex-col justify-center gap-1.5',
|
||||
'rounded-[10px] border-[0.5px] border-components-panel-border-subtle',
|
||||
'bg-components-panel-bg-blur px-2 pb-2 pt-1.5',
|
||||
'backdrop-blur-[10px]',
|
||||
'shadow-lg',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div
|
||||
id="inline-delete-confirm-title"
|
||||
className="system-xs-semibold text-text-primary"
|
||||
>
|
||||
{titleText}
|
||||
</div>
|
||||
|
||||
<div className="flex w-full items-center justify-center gap-1">
|
||||
<Button
|
||||
size="small"
|
||||
variant="secondary"
|
||||
onClick={onCancel}
|
||||
aria-label={cancelTxt}
|
||||
className="flex-1"
|
||||
>
|
||||
{cancelTxt}
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
variant="primary"
|
||||
destructive={variant === 'delete'}
|
||||
onClick={onConfirm}
|
||||
aria-label={confirmTxt}
|
||||
className="flex-1"
|
||||
>
|
||||
{confirmTxt}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<span id="inline-delete-confirm-description" className="sr-only">
|
||||
{t('common.operation.confirmAction', 'Please confirm your action.')}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
InlineDeleteConfirm.displayName = 'InlineDeleteConfirm'
|
||||
|
||||
export default InlineDeleteConfirm
|
||||
|
|
@ -7,6 +7,7 @@ import { useInvalidateStrategyProviders } from '@/service/use-strategy'
|
|||
import type { Plugin, PluginDeclaration, PluginManifestInMarket } from '../../types'
|
||||
import { PluginType } from '../../types'
|
||||
import { useInvalidDataSourceList } from '@/service/use-pipeline'
|
||||
import { useInvalidDataSourceListAuth } from '@/service/use-datasource'
|
||||
|
||||
const useRefreshPluginList = () => {
|
||||
const invalidateInstalledPluginList = useInvalidateInstalledPluginList()
|
||||
|
|
@ -19,6 +20,8 @@ const useRefreshPluginList = () => {
|
|||
const invalidateAllBuiltInTools = useInvalidateAllBuiltInTools()
|
||||
const invalidateAllDataSources = useInvalidDataSourceList()
|
||||
|
||||
const invalidateDataSourceListAuth = useInvalidDataSourceListAuth()
|
||||
|
||||
const invalidateStrategyProviders = useInvalidateStrategyProviders()
|
||||
return {
|
||||
refreshPluginList: (manifest?: PluginManifestInMarket | Plugin | PluginDeclaration | null, refreshAllType?: boolean) => {
|
||||
|
|
@ -32,8 +35,10 @@ const useRefreshPluginList = () => {
|
|||
// TODO: update suggested tools. It's a function in hook useMarketplacePlugins,handleUpdatePlugins
|
||||
}
|
||||
|
||||
if ((manifest && PluginType.datasource.includes(manifest.category)) || refreshAllType)
|
||||
if ((manifest && PluginType.datasource.includes(manifest.category)) || refreshAllType) {
|
||||
invalidateAllDataSources()
|
||||
invalidateDataSourceListAuth()
|
||||
}
|
||||
|
||||
// model select
|
||||
if ((manifest && PluginType.model.includes(manifest.category)) || refreshAllType) {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
} from 'react'
|
||||
import Link from 'next/link'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiArrowRightUpLine } from '@remixicon/react'
|
||||
import { BlockEnum } from '../types'
|
||||
import type {
|
||||
OnSelectBlock,
|
||||
|
|
@ -14,10 +13,12 @@ import type { DataSourceDefaultValue, ToolDefaultValue } from './types'
|
|||
import Tools from './tools'
|
||||
import { ViewType } from './view-type-select'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { ListRef } from '@/app/components/workflow/block-selector/market-place-plugin/list'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
import PluginList, { type ListRef } from '@/app/components/workflow/block-selector/market-place-plugin/list'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { DEFAULT_FILE_EXTENSIONS_IN_LOCAL_FILE_DATA_SOURCE } from './constants'
|
||||
import { useMarketplacePlugins } from '../../plugins/marketplace/hooks'
|
||||
import { PluginType } from '../../plugins/types'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
|
||||
type AllToolsProps = {
|
||||
className?: string
|
||||
|
|
@ -34,9 +35,26 @@ const DataSources = ({
|
|||
onSelect,
|
||||
dataSources,
|
||||
}: AllToolsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const language = useGetLanguage()
|
||||
const pluginRef = useRef<ListRef>(null)
|
||||
const wrapElemRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const isMatchingKeywords = (text: string, keywords: string) => {
|
||||
return text.toLowerCase().includes(keywords.toLowerCase())
|
||||
}
|
||||
|
||||
const filteredDatasources = useMemo(() => {
|
||||
const hasFilter = searchText
|
||||
if (!hasFilter)
|
||||
return dataSources.filter(toolWithProvider => toolWithProvider.tools.length > 0)
|
||||
|
||||
return dataSources.filter((toolWithProvider) => {
|
||||
return isMatchingKeywords(toolWithProvider.name, searchText) || toolWithProvider.tools.some((tool) => {
|
||||
return tool.label[language].toLowerCase().includes(searchText.toLowerCase()) || tool.name.toLowerCase().includes(searchText.toLowerCase())
|
||||
})
|
||||
})
|
||||
}, [searchText, dataSources, language])
|
||||
|
||||
const handleSelect = useCallback((_: any, toolDefaultValue: ToolDefaultValue) => {
|
||||
let defaultValue: DataSourceDefaultValue = {
|
||||
plugin_id: toolDefaultValue?.provider_id,
|
||||
|
|
@ -55,8 +73,24 @@ const DataSources = ({
|
|||
}
|
||||
onSelect(BlockEnum.DataSource, toolDefaultValue && defaultValue)
|
||||
}, [onSelect])
|
||||
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
||||
const {
|
||||
queryPluginsWithDebounced: fetchPlugins,
|
||||
plugins: notInstalledPlugins = [],
|
||||
} = useMarketplacePlugins()
|
||||
|
||||
useEffect(() => {
|
||||
if (!enable_marketplace) return
|
||||
if (searchText) {
|
||||
fetchPlugins({
|
||||
query: searchText,
|
||||
category: PluginType.datasource,
|
||||
})
|
||||
}
|
||||
}, [searchText, enable_marketplace])
|
||||
|
||||
return (
|
||||
<div className={cn(className)}>
|
||||
<div
|
||||
|
|
@ -66,24 +100,23 @@ const DataSources = ({
|
|||
>
|
||||
<Tools
|
||||
className={toolContentClassName}
|
||||
tools={dataSources}
|
||||
tools={filteredDatasources}
|
||||
onSelect={handleSelect as OnSelectBlock}
|
||||
viewType={ViewType.flat}
|
||||
hasSearchText={!!searchText}
|
||||
canNotSelectMultiple
|
||||
/>
|
||||
{
|
||||
enable_marketplace && (
|
||||
<Link
|
||||
className='system-sm-medium sticky bottom-0 z-10 flex h-8 cursor-pointer items-center rounded-b-lg border-[0.5px] border-t border-components-panel-border bg-components-panel-bg-blur px-4 py-1 text-text-accent-light-mode-only shadow-lg'
|
||||
href={getMarketplaceUrl('')}
|
||||
target='_blank'
|
||||
>
|
||||
<span>{t('plugin.findMoreInMarketplace')}</span>
|
||||
<RiArrowRightUpLine className='ml-0.5 h-3 w-3' />
|
||||
</Link>
|
||||
)
|
||||
}
|
||||
{/* Plugins from marketplace */}
|
||||
{enable_marketplace && (
|
||||
<PluginList
|
||||
ref={pluginRef}
|
||||
wrapElemRef={wrapElemRef}
|
||||
list={notInstalledPlugins}
|
||||
tags={[]}
|
||||
searchText={searchText}
|
||||
toolContentClassName={toolContentClassName}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
import type { ComponentType } from 'react'
|
||||
import { BlockEnum } from '../types'
|
||||
import StartNode from './start/node'
|
||||
import StartPanel from './start/panel'
|
||||
import EndNode from './end/node'
|
||||
import EndPanel from './end/panel'
|
||||
import AnswerNode from './answer/node'
|
||||
import AnswerPanel from './answer/panel'
|
||||
import LLMNode from './llm/node'
|
||||
import LLMPanel from './llm/panel'
|
||||
import KnowledgeRetrievalNode from './knowledge-retrieval/node'
|
||||
import KnowledgeRetrievalPanel from './knowledge-retrieval/panel'
|
||||
import QuestionClassifierNode from './question-classifier/node'
|
||||
import QuestionClassifierPanel from './question-classifier/panel'
|
||||
import IfElseNode from './if-else/node'
|
||||
import IfElsePanel from './if-else/panel'
|
||||
import CodeNode from './code/node'
|
||||
import CodePanel from './code/panel'
|
||||
import TemplateTransformNode from './template-transform/node'
|
||||
import TemplateTransformPanel from './template-transform/panel'
|
||||
import HttpNode from './http/node'
|
||||
import HttpPanel from './http/panel'
|
||||
import ToolNode from './tool/node'
|
||||
import ToolPanel from './tool/panel'
|
||||
import VariableAssignerNode from './variable-assigner/node'
|
||||
import VariableAssignerPanel from './variable-assigner/panel'
|
||||
import AssignerNode from './assigner/node'
|
||||
import AssignerPanel from './assigner/panel'
|
||||
import ParameterExtractorNode from './parameter-extractor/node'
|
||||
import ParameterExtractorPanel from './parameter-extractor/panel'
|
||||
import IterationNode from './iteration/node'
|
||||
import IterationPanel from './iteration/panel'
|
||||
import LoopNode from './loop/node'
|
||||
import LoopPanel from './loop/panel'
|
||||
import DocExtractorNode from './document-extractor/node'
|
||||
import DocExtractorPanel from './document-extractor/panel'
|
||||
import ListFilterNode from './list-operator/node'
|
||||
import ListFilterPanel from './list-operator/panel'
|
||||
import AgentNode from './agent/node'
|
||||
import AgentPanel from './agent/panel'
|
||||
import DataSourceNode from './data-source/node'
|
||||
import DataSourcePanel from './data-source/panel'
|
||||
import KnowledgeBaseNode from './knowledge-base/node'
|
||||
import KnowledgeBasePanel from './knowledge-base/panel'
|
||||
|
||||
export const NodeComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.Start]: StartNode,
|
||||
[BlockEnum.End]: EndNode,
|
||||
[BlockEnum.Answer]: AnswerNode,
|
||||
[BlockEnum.LLM]: LLMNode,
|
||||
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalNode,
|
||||
[BlockEnum.QuestionClassifier]: QuestionClassifierNode,
|
||||
[BlockEnum.IfElse]: IfElseNode,
|
||||
[BlockEnum.Code]: CodeNode,
|
||||
[BlockEnum.TemplateTransform]: TemplateTransformNode,
|
||||
[BlockEnum.HttpRequest]: HttpNode,
|
||||
[BlockEnum.Tool]: ToolNode,
|
||||
[BlockEnum.VariableAssigner]: VariableAssignerNode,
|
||||
[BlockEnum.Assigner]: AssignerNode,
|
||||
[BlockEnum.VariableAggregator]: VariableAssignerNode,
|
||||
[BlockEnum.ParameterExtractor]: ParameterExtractorNode,
|
||||
[BlockEnum.Iteration]: IterationNode,
|
||||
[BlockEnum.Loop]: LoopNode,
|
||||
[BlockEnum.DocExtractor]: DocExtractorNode,
|
||||
[BlockEnum.ListFilter]: ListFilterNode,
|
||||
[BlockEnum.Agent]: AgentNode,
|
||||
[BlockEnum.DataSource]: DataSourceNode,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBaseNode,
|
||||
}
|
||||
|
||||
export const PanelComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.Start]: StartPanel,
|
||||
[BlockEnum.End]: EndPanel,
|
||||
[BlockEnum.Answer]: AnswerPanel,
|
||||
[BlockEnum.LLM]: LLMPanel,
|
||||
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalPanel,
|
||||
[BlockEnum.QuestionClassifier]: QuestionClassifierPanel,
|
||||
[BlockEnum.IfElse]: IfElsePanel,
|
||||
[BlockEnum.Code]: CodePanel,
|
||||
[BlockEnum.TemplateTransform]: TemplateTransformPanel,
|
||||
[BlockEnum.HttpRequest]: HttpPanel,
|
||||
[BlockEnum.Tool]: ToolPanel,
|
||||
[BlockEnum.VariableAssigner]: VariableAssignerPanel,
|
||||
[BlockEnum.VariableAggregator]: VariableAssignerPanel,
|
||||
[BlockEnum.Assigner]: AssignerPanel,
|
||||
[BlockEnum.ParameterExtractor]: ParameterExtractorPanel,
|
||||
[BlockEnum.Iteration]: IterationPanel,
|
||||
[BlockEnum.Loop]: LoopPanel,
|
||||
[BlockEnum.DocExtractor]: DocExtractorPanel,
|
||||
[BlockEnum.ListFilter]: ListFilterPanel,
|
||||
[BlockEnum.Agent]: AgentPanel,
|
||||
[BlockEnum.DataSource]: DataSourcePanel,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBasePanel,
|
||||
}
|
||||
|
|
@ -1,101 +1,5 @@
|
|||
import type { ComponentType } from 'react'
|
||||
import { BlockEnum } from '../types'
|
||||
import StartNode from './start/node'
|
||||
import StartPanel from './start/panel'
|
||||
import EndNode from './end/node'
|
||||
import EndPanel from './end/panel'
|
||||
import AnswerNode from './answer/node'
|
||||
import AnswerPanel from './answer/panel'
|
||||
import LLMNode from './llm/node'
|
||||
import LLMPanel from './llm/panel'
|
||||
import KnowledgeRetrievalNode from './knowledge-retrieval/node'
|
||||
import KnowledgeRetrievalPanel from './knowledge-retrieval/panel'
|
||||
import QuestionClassifierNode from './question-classifier/node'
|
||||
import QuestionClassifierPanel from './question-classifier/panel'
|
||||
import IfElseNode from './if-else/node'
|
||||
import IfElsePanel from './if-else/panel'
|
||||
import CodeNode from './code/node'
|
||||
import CodePanel from './code/panel'
|
||||
import TemplateTransformNode from './template-transform/node'
|
||||
import TemplateTransformPanel from './template-transform/panel'
|
||||
import HttpNode from './http/node'
|
||||
import HttpPanel from './http/panel'
|
||||
import ToolNode from './tool/node'
|
||||
import ToolPanel from './tool/panel'
|
||||
import VariableAssignerNode from './variable-assigner/node'
|
||||
import VariableAssignerPanel from './variable-assigner/panel'
|
||||
import AssignerNode from './assigner/node'
|
||||
import AssignerPanel from './assigner/panel'
|
||||
import ParameterExtractorNode from './parameter-extractor/node'
|
||||
import ParameterExtractorPanel from './parameter-extractor/panel'
|
||||
import IterationNode from './iteration/node'
|
||||
import IterationPanel from './iteration/panel'
|
||||
import LoopNode from './loop/node'
|
||||
import LoopPanel from './loop/panel'
|
||||
import DocExtractorNode from './document-extractor/node'
|
||||
import DocExtractorPanel from './document-extractor/panel'
|
||||
import ListFilterNode from './list-operator/node'
|
||||
import ListFilterPanel from './list-operator/panel'
|
||||
import AgentNode from './agent/node'
|
||||
import AgentPanel from './agent/panel'
|
||||
import DataSourceNode from './data-source/node'
|
||||
import DataSourcePanel from './data-source/panel'
|
||||
import KnowledgeBaseNode from './knowledge-base/node'
|
||||
import KnowledgeBasePanel from './knowledge-base/panel'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
|
||||
export const NodeComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.Start]: StartNode,
|
||||
[BlockEnum.End]: EndNode,
|
||||
[BlockEnum.Answer]: AnswerNode,
|
||||
[BlockEnum.LLM]: LLMNode,
|
||||
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalNode,
|
||||
[BlockEnum.QuestionClassifier]: QuestionClassifierNode,
|
||||
[BlockEnum.IfElse]: IfElseNode,
|
||||
[BlockEnum.Code]: CodeNode,
|
||||
[BlockEnum.TemplateTransform]: TemplateTransformNode,
|
||||
[BlockEnum.HttpRequest]: HttpNode,
|
||||
[BlockEnum.Tool]: ToolNode,
|
||||
[BlockEnum.VariableAssigner]: VariableAssignerNode,
|
||||
[BlockEnum.Assigner]: AssignerNode,
|
||||
[BlockEnum.VariableAggregator]: VariableAssignerNode,
|
||||
[BlockEnum.ParameterExtractor]: ParameterExtractorNode,
|
||||
[BlockEnum.Iteration]: IterationNode,
|
||||
[BlockEnum.Loop]: LoopNode,
|
||||
[BlockEnum.DocExtractor]: DocExtractorNode,
|
||||
[BlockEnum.ListFilter]: ListFilterNode,
|
||||
[BlockEnum.Agent]: AgentNode,
|
||||
[BlockEnum.DataSource]: DataSourceNode,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBaseNode,
|
||||
}
|
||||
|
||||
export const PanelComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.Start]: StartPanel,
|
||||
[BlockEnum.End]: EndPanel,
|
||||
[BlockEnum.Answer]: AnswerPanel,
|
||||
[BlockEnum.LLM]: LLMPanel,
|
||||
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalPanel,
|
||||
[BlockEnum.QuestionClassifier]: QuestionClassifierPanel,
|
||||
[BlockEnum.IfElse]: IfElsePanel,
|
||||
[BlockEnum.Code]: CodePanel,
|
||||
[BlockEnum.TemplateTransform]: TemplateTransformPanel,
|
||||
[BlockEnum.HttpRequest]: HttpPanel,
|
||||
[BlockEnum.Tool]: ToolPanel,
|
||||
[BlockEnum.VariableAssigner]: VariableAssignerPanel,
|
||||
[BlockEnum.VariableAggregator]: VariableAssignerPanel,
|
||||
[BlockEnum.Assigner]: AssignerPanel,
|
||||
[BlockEnum.ParameterExtractor]: ParameterExtractorPanel,
|
||||
[BlockEnum.Iteration]: IterationPanel,
|
||||
[BlockEnum.Loop]: LoopPanel,
|
||||
[BlockEnum.DocExtractor]: DocExtractorPanel,
|
||||
[BlockEnum.ListFilter]: ListFilterPanel,
|
||||
[BlockEnum.Agent]: AgentPanel,
|
||||
[BlockEnum.DataSource]: DataSourcePanel,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBasePanel,
|
||||
}
|
||||
|
||||
export const CUSTOM_NODE_TYPE = 'custom'
|
||||
|
||||
export const FILE_TYPE_OPTIONS = [
|
||||
{ value: 'image', i18nKey: 'image' },
|
||||
{ value: 'document', i18nKey: 'doc' },
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import { CUSTOM_NODE } from '../constants'
|
|||
import {
|
||||
NodeComponentMap,
|
||||
PanelComponentMap,
|
||||
} from './constants'
|
||||
} from './components'
|
||||
import BaseNode from './_base/node'
|
||||
import BasePanel from './_base/components/workflow-panel'
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ const ChunkStructure = ({
|
|||
<Field
|
||||
fieldTitleProps={{
|
||||
title: t('workflow.nodes.knowledgeBase.chunkStructure'),
|
||||
tooltip: t('workflow.nodes.knowledgeBase.chunkStructure'),
|
||||
tooltip: t('workflow.nodes.knowledgeBase.chunkStructureTip.message'),
|
||||
operation: chunkStructure && (
|
||||
<Selector
|
||||
options={options}
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
selectAll: 'Alles auswählen',
|
||||
deSelectAll: 'Alle abwählen',
|
||||
config: 'Konfiguration',
|
||||
yes: 'Ja',
|
||||
deleteConfirmTitle: 'Löschen?',
|
||||
no: 'Nein',
|
||||
confirmAction: 'Bitte bestätigen Sie Ihre Aktion.',
|
||||
},
|
||||
placeholder: {
|
||||
input: 'Bitte eingeben',
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ const translation = {
|
|||
cancel: 'Cancel',
|
||||
clear: 'Clear',
|
||||
save: 'Save',
|
||||
yes: 'Yes',
|
||||
no: 'No',
|
||||
deleteConfirmTitle: 'Delete?',
|
||||
confirmAction: 'Please confirm your action.',
|
||||
saveAndEnable: 'Save & Enable',
|
||||
edit: 'Edit',
|
||||
add: 'Add',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
deSelectAll: 'Deseleccionar todo',
|
||||
selectAll: 'Seleccionar todo',
|
||||
config: 'Config',
|
||||
confirmAction: 'Por favor, confirme su acción.',
|
||||
deleteConfirmTitle: '¿Eliminar?',
|
||||
yes: 'Sí',
|
||||
no: 'No',
|
||||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} es requerido',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
selectAll: 'انتخاب همه',
|
||||
deSelectAll: 'همه را انتخاب نکنید',
|
||||
config: 'تنظیمات',
|
||||
no: 'نه',
|
||||
deleteConfirmTitle: 'حذف شود؟',
|
||||
yes: 'بله',
|
||||
confirmAction: 'لطفاً اقدام خود را تأیید کنید.',
|
||||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} الزامی است',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
deSelectAll: 'Désélectionner tout',
|
||||
selectAll: 'Sélectionner tout',
|
||||
config: 'Config',
|
||||
no: 'Non',
|
||||
confirmAction: 'Veuillez confirmer votre action.',
|
||||
deleteConfirmTitle: 'Supprimer ?',
|
||||
yes: 'Oui',
|
||||
},
|
||||
placeholder: {
|
||||
input: 'Veuillez entrer',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
selectAll: 'सभी चुनें',
|
||||
deSelectAll: 'सभी चयन हटाएँ',
|
||||
config: 'कॉन्फ़िगरेशन',
|
||||
no: 'नहीं',
|
||||
yes: 'हाँ',
|
||||
deleteConfirmTitle: 'हटाएं?',
|
||||
confirmAction: 'कृपया अपनी क्रिया की पुष्टि करें।',
|
||||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} आवश्यक है',
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ const translation = {
|
|||
sure: 'Saya yakin',
|
||||
imageCopied: 'Gambar yang disalin',
|
||||
config: 'Konfigurasi',
|
||||
deleteConfirmTitle: 'Hapus?',
|
||||
confirmAction: 'Silakan konfirmasi tindakan Anda.',
|
||||
yes: 'Ya',
|
||||
no: 'Tidak',
|
||||
},
|
||||
errorMsg: {
|
||||
urlError: 'URL harus dimulai dengan http:// atau https://',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
selectAll: 'Seleziona tutto',
|
||||
deSelectAll: 'Deseleziona tutto',
|
||||
config: 'Config',
|
||||
no: 'No',
|
||||
yes: 'Sì',
|
||||
confirmAction: 'Per favore conferma la tua azione.',
|
||||
deleteConfirmTitle: 'Eliminare?',
|
||||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} è obbligatorio',
|
||||
|
|
|
|||
|
|
@ -67,6 +67,10 @@ const translation = {
|
|||
selectAll: 'すべて選択',
|
||||
deSelectAll: 'すべて選択解除',
|
||||
config: 'コンフィグ',
|
||||
yes: 'はい',
|
||||
no: 'いいえ',
|
||||
deleteConfirmTitle: '削除しますか?',
|
||||
confirmAction: '操作を確認してください。',
|
||||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}}は必要です',
|
||||
|
|
|
|||
|
|
@ -61,6 +61,10 @@ const translation = {
|
|||
selectAll: '모두 선택',
|
||||
deSelectAll: '모두 선택 해제',
|
||||
config: '구성',
|
||||
no: '아니요',
|
||||
yes: '네',
|
||||
deleteConfirmTitle: '삭제하시겠습니까?',
|
||||
confirmAction: '귀하의 행동을 확인해 주세요.',
|
||||
},
|
||||
placeholder: {
|
||||
input: '입력해주세요',
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue