Merge branch 'main' into feat/mcp-authentication

zxhlyh 2025-10-13 16:52:42 +08:00
commit 9c6d059227
114 changed files with 3941 additions and 1416 deletions

View File

@ -1521,6 +1521,14 @@ def transform_datasource_credentials():
auth_count = 0
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
if not firecrawl_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
@ -1576,6 +1584,14 @@ def transform_datasource_credentials():
auth_count = 0
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
if not jina_tenant_credential.credentials:
click.echo(
click.style(
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
fg="yellow",
)
)
continue
# get credential api key
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
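The new skip-and-warn branches above matter because json.loads raises on empty or missing credentials rather than returning an empty dict. A minimal standalone sketch of the failure mode being guarded against (not part of the diff):

import json

try:
    json.loads("")    # empty credentials string
except json.JSONDecodeError:
    print("empty string: JSONDecodeError")

try:
    json.loads(None)  # NULL credentials column  # type: ignore[arg-type]
except TypeError:
    print("None: TypeError")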

View File

@ -1,5 +1,4 @@
import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Dataset
from models.model import ApiToken, App
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
current_key_count = (
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner

View File

@ -1,9 +1,9 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.billing_service import BillingService
from .. import console_ns
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()

View File

@ -45,6 +45,79 @@ def _validate_name(name: str) -> str:
return name
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
"""
Get supported retrieval methods based on vector database type.
Args:
vector_type: Vector database type; may be None
is_mock: Whether this is the mock API; affects how MILVUS is handled
Returns:
Dictionary containing supported retrieval methods
Raises:
ValueError: If vector_type is None or unsupported
"""
if vector_type is None:
raise ValueError("Vector store type is not configured.")
# Define vector database types that only support semantic search
semantic_only_types = {
VectorType.RELYT,
VectorType.TIDB_VECTOR,
VectorType.CHROMA,
VectorType.PGVECTO_RS,
VectorType.VIKINGDB,
VectorType.UPSTASH,
}
# Define vector database types that support all retrieval methods
full_search_types = {
VectorType.QDRANT,
VectorType.WEAVIATE,
VectorType.OPENSEARCH,
VectorType.ANALYTICDB,
VectorType.MYSCALE,
VectorType.ORACLE,
VectorType.ELASTICSEARCH,
VectorType.ELASTICSEARCH_JA,
VectorType.PGVECTOR,
VectorType.VASTBASE,
VectorType.TIDB_ON_QDRANT,
VectorType.LINDORM,
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
VectorType.MATRIXONE,
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
full_methods = {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
]
}
if vector_type == VectorType.MILVUS:
return semantic_methods if is_mock else full_methods
if vector_type in semantic_only_types:
return semantic_methods
elif vector_type in full_search_types:
return full_methods
else:
raise ValueError(f"Unsupported vector db type {vector_type}.")
@console_ns.route("/datasets")
class DatasetListApi(Resource):
@api.doc("get_datasets")
@ -777,50 +850,7 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
match vector_type:
case (
VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
| VectorType.MILVUS
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -833,49 +863,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
match vector_type:
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
| VectorType.ALIBABACLOUD_MYSQL
):
return {
"retrieval_method": [
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")

View File

@ -1,7 +1,5 @@
import logging
from typing import cast
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -21,6 +19,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,

View File

@ -2,15 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_user, login_required
from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
installed_app = (
db.session.query(InstalledApp)
.where(
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(

View File

@ -1,11 +1,11 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required
@account_initialization_required
def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id

View File

@ -1,7 +1,7 @@
from flask_login import current_user
from flask_restx import Resource, fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.feature_service import FeatureService
from . import api, console_ns
@ -23,6 +23,8 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()

View File

@ -1,8 +1,6 @@
import urllib.parse
from typing import cast
import httpx
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from services.file_service import FileService
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user = cast(Account, current_user)
assert isinstance(current_user, Account)
user = current_user
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@ -1,12 +1,12 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.model import Tag
from services.tag_service import TagService
@ -24,6 +24,8 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
@ -34,8 +36,10 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id)
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -1,10 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.agent_service import AgentService
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@ -1,7 +1,6 @@
from urllib import parse
from flask import abort, request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@ -26,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError

View File

@ -1,7 +1,6 @@
import logging
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -24,7 +23,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService

View File

@ -7,13 +7,13 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import AccountStatus
from libs.login import current_user
from models.account import Account, AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -25,11 +25,16 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
account = _current_account()
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
OperationService.record_utm(tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
def knowledge_pipeline_publish_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
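The assert-isinstance pattern that recurs throughout this diff (and the _current_account() helper that centralizes it here) exists because current_user is a proxy whose static type is a union; the assert narrows the type for checkers like mypy and fails fast at runtime. A self-contained sketch of the idea, with illustrative stand-ins rather than the real models:

class Account:
    current_tenant_id: str | None = None

class EndUser:
    pass

def _current_account_sketch(user: Account | EndUser) -> Account:
    # The isinstance assert narrows the union for static checkers and
    # raises AssertionError immediately if a non-account principal slips through.
    assert isinstance(user, Account)
    return user

acct = _current_account_sketch(Account())  # ok
# _current_account_sketch(EndUser())       # AssertionError at runtime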

View File

@ -70,7 +70,11 @@ class ModelConfigConverter:
if not model_mode:
model_mode = LLMMode.CHAT
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
try:
model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE])
except ValueError:
# Fall back to CHAT mode if the stored value is invalid
model_mode = LLMMode.CHAT
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -1,5 +1,5 @@
import enum
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
@ -218,7 +218,7 @@ class DatasourceLabel(BaseModel):
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
class DatasourceInvokeFrom(StrEnum):
"""
Enum class for datasource invoke
"""

View File

@ -1414,7 +1414,7 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
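Dropping .value in this comparison is safe only because FormType is a StrEnum; plain Enum members never compare equal to strings. A minimal demonstration with illustrative enum names (not the real FormType):

from enum import Enum, StrEnum

class PlainForm(Enum):
    SECRET_INPUT = "secret-input"

class StrForm(StrEnum):
    SECRET_INPUT = "secret-input"

assert PlainForm.SECRET_INPUT != "secret-input"  # plain Enum: never equal to a str
assert StrForm.SECRET_INPUT == "secret-input"    # StrEnum member IS a str
assert StrForm.SECRET_INPUT == StrForm.SECRET_INPUT.value  # comparing .value still works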

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
@ -7,7 +7,7 @@ from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
class ConfigurateMethod(Enum):
class ConfigurateMethod(StrEnum):
"""
Enum class for configurate method of provider model.
"""

View File

@ -255,7 +255,7 @@ class BasePluginClient:
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
logger.error("Error in stream reponse for plugin %s", rep.__dict__)
logger.error("Error in stream response for plugin %s", rep.__dict__)
self._handle_plugin_daemon_error(error.error_type, error.message)
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
if rep.data is None:

View File

@ -1046,7 +1046,7 @@ class ProviderManager:
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT:
if credential_form_schema.type == FormType.SECRET_INPUT:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables

View File

@ -34,7 +34,7 @@ class RetrievalService:
@classmethod
def retrieve(
cls,
retrieval_method: str,
retrieval_method: RetrievalMethod,
dataset_id: str,
query: str,
top_k: int,
@ -56,7 +56,7 @@ class RetrievalService:
# Optimize multithreading with thread pools
with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore
futures = []
if retrieval_method == "keyword_search":
if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
futures.append(
executor.submit(
cls.keyword_search,
@ -220,7 +220,7 @@ class RetrievalService:
score_threshold: float | None,
reranking_model: dict | None,
all_documents: list,
retrieval_method: str,
retrieval_method: RetrievalMethod,
exceptions: list,
document_ids_filter: list[str] | None = None,
):

View File

@ -1,11 +1,11 @@
from collections.abc import Mapping
from enum import Enum
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class DatasourceStreamEvent(Enum):
class DatasourceStreamEvent(StrEnum):
"""
Datasource Stream event
"""

View File

@ -1,6 +1,7 @@
import logging
import os
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -49,7 +50,8 @@ class UnstructuredWordExtractor(BaseExtractor):
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()
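Across all of these extractors, both chunking knobs now come from a single config value instead of the hard-coded 2000, making chunk size tunable per deployment. A hedged usage sketch of unstructured's chunk_by_title; the input file name and the stand-in value are illustrative:

from unstructured.partition.md import partition_md
from unstructured.chunking.title import chunk_by_title

elements = partition_md(filename="example.md")  # illustrative input file
max_chars = 4000  # stand-in for dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_chars, combine_text_under_n_chars=max_chars)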

View File

@ -4,6 +4,7 @@ import logging
from bs4 import BeautifulSoup
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -46,7 +47,8 @@ class UnstructuredEmailExtractor(BaseExtractor):
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

View File

@ -2,6 +2,7 @@ import logging
import pypandoc # type: ignore
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -40,7 +41,8 @@ class UnstructuredEpubExtractor(BaseExtractor):
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

View File

@ -1,5 +1,6 @@
import logging
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -32,7 +33,8 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
elements = partition_md(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

View File

@ -1,5 +1,6 @@
import logging
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -31,7 +32,8 @@ class UnstructuredMsgExtractor(BaseExtractor):
elements = partition_msg(filename=self._file_path)
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

View File

@ -1,5 +1,6 @@
import logging
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@ -32,7 +33,8 @@ class UnstructuredXmlExtractor(BaseExtractor):
from unstructured.chunking.title import chunk_by_title
chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
max_characters = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
chunks = chunk_by_title(elements, max_characters=max_characters, combine_text_under_n_chars=max_characters)
documents = []
for chunk in chunks:
text = chunk.text.strip()

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
from configs import dify_config
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
@ -49,7 +50,7 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@ -14,6 +14,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
@ -106,7 +107,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@ -16,6 +16,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
@ -161,7 +162,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@ -21,6 +21,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document, QAStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
@ -141,7 +142,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def retrieve(
self,
retrieval_method: str,
retrieval_method: RetrievalMethod,
query: str,
dataset: Dataset,
top_k: int,

View File

@ -364,7 +364,7 @@ class DatasetRetrieval:
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = "keyword_search"
retrieval_method = RetrievalMethod.KEYWORD_SEARCH
else:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
@ -623,7 +623,7 @@ class DatasetRetrieval:
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=top_k,

View File

@ -1,7 +1,7 @@
from enum import Enum
from enum import StrEnum
class RetrievalMethod(Enum):
class RetrievalMethod(StrEnum):
SEMANTIC_SEARCH = "semantic_search"
FULL_TEXT_SEARCH = "full_text_search"
HYBRID_SEARCH = "hybrid_search"

View File

@ -76,7 +76,8 @@ class MCPToolProviderController(ToolProviderController):
)
for remote_mcp_tool in remote_mcp_tools
]
if not db_provider.icon:
raise ValueError("Database provider icon is required")
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(

View File

@ -172,7 +172,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 4,

View File

@ -130,7 +130,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
retrieval_method=RetrievalMethod.KEYWORD_SEARCH,
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@ -5,6 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from decimal import Decimal
from typing import cast
from core.model_manager import ModelManager
@ -118,10 +119,10 @@ class ModelInvocationUtils:
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal(),
answer_price_unit=Decimal(),
provider_response_latency=0,
total_price=0,
total_price=Decimal(),
currency="USD",
)
@ -152,7 +153,7 @@ class ModelInvocationUtils:
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = response.message.content
tool_model_invoke.model_response = str(response.message.content)
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
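Using Decimal() rather than the literal 0 keeps the zero typed like the price fields it initializes; Decimal() is exactly zero, and decimal arithmetic avoids binary floating-point drift. A two-line check:

from decimal import Decimal

assert Decimal() == Decimal("0")                          # Decimal() is zero
assert Decimal("0.1") + Decimal("0.2") == Decimal("0.3")  # exact decimal arithmetic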

View File

@ -1,7 +1,7 @@
from enum import Enum, StrEnum
from enum import StrEnum
class NodeState(Enum):
class NodeState(StrEnum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"

View File

@ -10,7 +10,7 @@ When limits are exceeded, the layer automatically aborts execution.
import logging
import time
from enum import Enum
from enum import StrEnum
from typing import final
from typing_extensions import override
@ -24,7 +24,7 @@ from core.workflow.graph_events import (
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
class LimitType(Enum):
class LimitType(StrEnum):
"""Types of execution limits that can be exceeded."""
STEP_LIMIT = "step_limit"

View File

@ -2,6 +2,7 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.base import BaseNodeData
@ -63,7 +64,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "full_text_search", "hybrid_search"]
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False

View File

@ -37,10 +37,11 @@ config.set_main_option('sqlalchemy.url', get_engine_url())
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
from models.base import Base
from models.base import TypeBase
def get_metadata():
return Base.metadata
return TypeBase.metadata
def include_object(object, name, type_, reflected, compare_to):
if type_ == "foreign_key_constraint":

View File

@ -6,12 +6,12 @@ from sqlalchemy import DateTime, String, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from models.base import Base
from models.base import TypeBase
from .types import StringUUID
class DataSourceOauthBinding(Base):
class DataSourceOauthBinding(TypeBase):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="source_binding_pkey"),
@ -19,17 +19,25 @@ class DataSourceOauthBinding(Base):
sa.Index("source_info_idx", "source_info", postgresql_using="gin"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info = mapped_column(JSONB, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
source_info: Mapped[dict] = mapped_column(JSONB, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
class DataSourceApiKeyAuthBinding(Base):
class DataSourceApiKeyAuthBinding(TypeBase):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
@ -37,14 +45,22 @@ class DataSourceApiKeyAuthBinding(Base):
sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
credentials = mapped_column(sa.Text, nullable=True) # JSON
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
disabled: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None) # JSON
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
def to_dict(self):
return {
@ -52,7 +68,7 @@ class DataSourceApiKeyAuthBinding(Base):
"tenant_id": self.tenant_id,
"category": self.category,
"provider": self.provider,
"credentials": json.loads(self.credentials),
"credentials": json.loads(self.credentials) if self.credentials else None,
"created_at": self.created_at.timestamp(),
"updated_at": self.updated_at.timestamp(),
"disabled": self.disabled,

View File

@ -6,41 +6,43 @@ from sqlalchemy import DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
from libs.datetime_utils import naive_utc_now
from models.base import Base
from models.base import TypeBase
class CeleryTask(Base):
class CeleryTask(TypeBase):
"""Task result/status."""
__tablename__ = "celery_taskmeta"
id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
task_id = mapped_column(String(155), unique=True)
status = mapped_column(String(50), default=states.PENDING)
result = mapped_column(sa.PickleType, nullable=True)
date_done = mapped_column(
id: Mapped[int] = mapped_column(
sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True, init=False
)
task_id: Mapped[str] = mapped_column(String(155), unique=True)
status: Mapped[str] = mapped_column(String(50), default=states.PENDING)
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(
DateTime,
default=lambda: naive_utc_now(),
onupdate=lambda: naive_utc_now(),
default=naive_utc_now,
onupdate=naive_utc_now,
nullable=True,
)
traceback = mapped_column(sa.Text, nullable=True)
name = mapped_column(String(155), nullable=True)
args = mapped_column(sa.LargeBinary, nullable=True)
kwargs = mapped_column(sa.LargeBinary, nullable=True)
worker = mapped_column(String(155), nullable=True)
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
queue = mapped_column(String(155), nullable=True)
traceback: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
name: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
args: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
kwargs: Mapped[bytes | None] = mapped_column(sa.LargeBinary, nullable=True, default=None)
worker: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
retries: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
queue: Mapped[str | None] = mapped_column(String(155), nullable=True, default=None)
class CeleryTaskSet(Base):
class CeleryTaskSet(TypeBase):
"""TaskSet result."""
__tablename__ = "celery_tasksetmeta"
id: Mapped[int] = mapped_column(
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True, init=False
)
taskset_id = mapped_column(String(155), unique=True)
result = mapped_column(sa.PickleType, nullable=True)
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True)
taskset_id: Mapped[str] = mapped_column(String(155), unique=True)
result: Mapped[bytes | None] = mapped_column(sa.PickleType, nullable=True, default=None)
date_done: Mapped[datetime | None] = mapped_column(DateTime, default=naive_utc_now, nullable=True)
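The move from Base to TypeBase across these models appears to adopt SQLAlchemy 2.0's dataclass mapping: init=False keeps server-generated columns (ids, timestamps) out of the generated __init__, and every remaining column needs either a default or a constructor argument. A self-contained sketch, assuming TypeBase mixes MappedAsDataclass into DeclarativeBase (the names here are illustrative, not models.base):

from datetime import datetime
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

class TypeBase(MappedAsDataclass, DeclarativeBase):  # assumed shape of models.base.TypeBase
    pass

class ExampleRow(TypeBase):
    __tablename__ = "example_rows"
    # init=False: excluded from __init__, value comes from the database
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
    name: Mapped[str] = mapped_column(String(50))
    created_at: Mapped[datetime] = mapped_column(
        DateTime, server_default=func.current_timestamp(), init=False
    )

row = ExampleRow(name="demo")  # only columns without init=False become constructor args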

View File

@ -1,6 +1,7 @@
import json
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse
@ -13,7 +14,7 @@ from core.helper import encrypter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from models.base import Base, TypeBase
from models.base import TypeBase
from .engine import db
from .model import Account, App, Tenant
@ -42,28 +43,28 @@ class ToolOAuthSystemClient(TypeBase):
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
class ToolOAuthTenantClient(TypeBase):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), init=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False, init=False)
@property
def oauth_params(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
class BuiltinToolProvider(TypeBase):
"""
This table stores the tool provider information for built-in tools for each tenant.
"""
@ -75,37 +76,45 @@ class BuiltinToolProvider(Base):
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
name: Mapped[str] = mapped_column(
String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying")
String(256),
nullable=False,
server_default=sa.text("'API KEY 1'::character varying"),
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the tool provider
provider: Mapped[str] = mapped_column(String(256), nullable=False)
# credential of the tool provider
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime,
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
onupdate=func.current_timestamp(),
init=False,
)
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
String(32), nullable=False, server_default=sa.text("'api-key'::character varying")
String(32), nullable=False, server_default=sa.text("'api-key'::character varying"), default="api-key"
)
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1)
@property
def credentials(self) -> dict[str, Any]:
if not self.encrypted_credentials:
return {}
return cast(dict[str, Any], json.loads(self.encrypted_credentials))
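The recurring `init=False` and `default=` arguments above are what the move from `Base` to `TypeBase` requires: under SQLAlchemy's dataclass mapping, every mapped column becomes a constructor argument unless it opts out. A minimal sketch of the pattern, assuming `TypeBase` combines `DeclarativeBase` with `MappedAsDataclass` (the real `models.base.TypeBase` may differ in detail):

```python
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class TypeBase(DeclarativeBase, MappedAsDataclass):
    pass


class Example(TypeBase):
    __tablename__ = "examples"

    # Server-generated columns opt out of __init__ entirely.
    id: Mapped[str] = mapped_column(
        sa.String(36), primary_key=True, server_default=sa.text("uuid_generate_v4()"), init=False
    )
    # Required columns stay as plain constructor arguments.
    name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    # Optional columns need an explicit dataclass default so callers may omit them.
    note: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)


# Only `name` is required now; `id` and `note` come from the server default / dataclass default.
example = Example(name="demo")
```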
class ApiToolProvider(Base):
class ApiToolProvider(TypeBase):
"""
The table stores the API providers.
"""
@@ -116,31 +125,43 @@ class ApiToolProvider(Base):
sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# name of the api provider
name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying"))
name: Mapped[str] = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'API KEY 1'::character varying"),
)
# icon
icon: Mapped[str] = mapped_column(String(255), nullable=False)
# original schema
schema = mapped_column(sa.Text, nullable=False)
schema: Mapped[str] = mapped_column(sa.Text, nullable=False)
schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False)
# who created this tool
user_id = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# description of the provider
description = mapped_column(sa.Text, nullable=False)
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# json format tools
tools_str = mapped_column(sa.Text, nullable=False)
tools_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
# json format credentials
credentials_str = mapped_column(sa.Text, nullable=False)
credentials_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
# privacy policy
privacy_policy = mapped_column(String(255), nullable=True)
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
# custom_disclaimer
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property
def schema_type(self) -> "ApiProviderSchemaType":
@@ -189,7 +210,7 @@ class ToolLabelBinding(TypeBase):
label_name: Mapped[str] = mapped_column(String(40), nullable=False)
class WorkflowToolProvider(Base):
class WorkflowToolProvider(TypeBase):
"""
The table stores the workflow providers.
"""
@@ -201,7 +222,7 @@ class WorkflowToolProvider(Base):
sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# name of the workflow provider
name: Mapped[str] = mapped_column(String(255), nullable=False)
# label of the workflow provider
@@ -219,15 +240,19 @@ class WorkflowToolProvider(Base):
# description of the provider
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# parameter configuration
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]")
parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]", default="[]")
# privacy policy
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="")
privacy_policy: Mapped[str | None] = mapped_column(String(255), nullable=True, server_default="", default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime,
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
onupdate=func.current_timestamp(),
init=False,
)
@property
@@ -252,7 +277,7 @@ class WorkflowToolProvider(Base):
return db.session.query(App).where(App.id == self.app_id).first()
class MCPToolProvider(Base):
class MCPToolProvider(TypeBase):
"""
The table stores the MCP providers.
"""
@@ -265,7 +290,7 @@ class MCPToolProvider(Base):
sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# name of the mcp provider
name: Mapped[str] = mapped_column(String(40), nullable=False)
# server identifier of the mcp provider
@@ -275,27 +300,33 @@ class MCPToolProvider(Base):
# hash of server_url for uniqueness check
server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False)
# icon of the mcp provider
icon: Mapped[str] = mapped_column(String(255), nullable=True)
icon: Mapped[str | None] = mapped_column(String(255), nullable=True)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# encrypted credentials
encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True)
encrypted_credentials: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
# authed
authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
# tools
tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")
sa.DateTime,
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
onupdate=func.current_timestamp(),
init=False,
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"), default=30.0)
sse_read_timeout: Mapped[float] = mapped_column(
sa.Float, nullable=False, server_default=sa.text("300"), default=300.0
)
timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30"))
sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300"))
# encrypted headers for MCP server requests
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
encrypted_headers: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.query(Account).where(Account.id == self.user_id).first()
@@ -306,9 +337,11 @@ class MCPToolProvider(Base):
@property
def credentials(self) -> dict[str, Any]:
if not self.encrypted_credentials:
return {}
try:
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except Exception:
except json.JSONDecodeError:
return {}
@property
@@ -321,6 +354,7 @@ class MCPToolProvider(Base):
def provider_icon(self) -> Mapping[str, str] | str:
from core.file import helpers as file_helpers
assert self.icon
try:
return json.loads(self.icon)
except json.JSONDecodeError:
@@ -419,7 +453,7 @@ class MCPToolProvider(Base):
return encrypter.decrypt(self.credentials)
class ToolModelInvoke(Base):
class ToolModelInvoke(TypeBase):
"""
Stores the invocation logs from tool calls.
"""
@@ -427,37 +461,47 @@
__tablename__ = "tool_model_invokes"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# who invoke this tool
user_id = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# provider
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# type
tool_type = mapped_column(String(40), nullable=False)
tool_type: Mapped[str] = mapped_column(String(40), nullable=False)
# tool name
tool_name = mapped_column(String(128), nullable=False)
tool_name: Mapped[str] = mapped_column(String(128), nullable=False)
# invoke parameters
model_parameters = mapped_column(sa.Text, nullable=False)
model_parameters: Mapped[str] = mapped_column(sa.Text, nullable=False)
# prompt messages
prompt_messages = mapped_column(sa.Text, nullable=False)
prompt_messages: Mapped[str] = mapped_column(sa.Text, nullable=False)
# invoke response
model_response = mapped_column(sa.Text, nullable=False)
model_response: Mapped[str] = mapped_column(sa.Text, nullable=False)
prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001"))
provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price = mapped_column(sa.Numeric(10, 7))
answer_unit_price: Mapped[Decimal] = mapped_column(sa.Numeric(10, 4), nullable=False)
answer_price_unit: Mapped[Decimal] = mapped_column(
sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")
)
provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7))
currency: Mapped[str] = mapped_column(String(255), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@deprecated
class ToolConversationVariables(Base):
class ToolConversationVariables(TypeBase):
"""
Stores the conversation variables from tool invocations.
"""
@@ -470,18 +514,26 @@
sa.Index("conversation_id_idx", "conversation_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# conversation user id
user_id = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# tenant id
tenant_id = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# conversation id
conversation_id = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# variables pool
variables_str = mapped_column(sa.Text, nullable=False)
variables_str: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property
def variables(self):
@@ -519,7 +571,7 @@ class ToolFile(TypeBase):
@deprecated
class DeprecatedPublishedAppTool(Base):
class DeprecatedPublishedAppTool(TypeBase):
"""
The table stores the apps each user has published as a tool.
"""
@@ -530,26 +582,34 @@
sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
# id of the app
app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who published this tool
description = mapped_column(sa.Text, nullable=False)
description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = mapped_column(sa.Text, nullable=False)
llm_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# query description, query will be seem as a parameter of the tool,
# to describe this parameter to llm, we need this field
query_description = mapped_column(sa.Text, nullable=False)
query_description: Mapped[str] = mapped_column(sa.Text, nullable=False)
# query name, the name of the query parameter
query_name = mapped_column(String(40), nullable=False)
query_name: Mapped[str] = mapped_column(String(40), nullable=False)
# name of the tool provider
tool_name = mapped_column(String(40), nullable=False)
tool_name: Mapped[str] = mapped_column(String(40), nullable=False)
# author
author = mapped_column(String(40), nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
author: Mapped[str] = mapped_column(String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=sa.text("CURRENT_TIMESTAMP(0)"),
onupdate=func.current_timestamp(),
init=False,
)
@property
def description_i18n(self) -> "I18nObject":

View File

@@ -4,46 +4,58 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from models.base import Base
from models.base import TypeBase
from .engine import db
from .model import Message
from .types import StringUUID
class SavedMessage(Base):
class SavedMessage(TypeBase):
__tablename__ = "saved_messages"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="saved_message_pkey"),
sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
message_id = mapped_column(StringUUID, nullable=False)
created_by_role = mapped_column(
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_by_role: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
init=False,
)
@property
def message(self):
return db.session.query(Message).where(Message.id == self.message_id).first()
class PinnedConversation(Base):
class PinnedConversation(TypeBase):
__tablename__ = "pinned_conversations"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = mapped_column(
String(255), nullable=False, server_default=sa.text("'end_user'::character varying")
created_by_role: Mapped[str] = mapped_column(
String(255),
nullable=False,
server_default=sa.text("'end_user'::character varying"),
)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
init=False,
)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())

View File

@@ -36,7 +36,7 @@ dependencies = [
"markdown~=3.5.1",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.7.25",
"opik~=1.8.72",
"opentelemetry-api==1.27.0",
"opentelemetry-distro==0.48b0",
"opentelemetry-exporter-otlp==1.27.0",

View File

@@ -26,10 +26,9 @@ class ApiKeyAuthService:
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
data_source_api_key_binding.tenant_id = tenant_id
data_source_api_key_binding.category = args["category"]
data_source_api_key_binding.provider = args["provider"]
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@@ -48,6 +47,8 @@
)
if not data_source_api_key_bindings:
return None
if not data_source_api_key_bindings.credentials:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials

View File

@@ -1470,7 +1470,7 @@ class DocumentService:
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
@@ -1752,7 +1752,7 @@ class DocumentService:
# dataset.collection_binding_id = dataset_collection_binding.id
# if not dataset.retrieval_model:
# default_retrieval_model = {
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
# "search_method": RetrievalMethod.SEMANTIC_SEARCH,
# "reranking_enable": False,
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
# "top_k": 2,
@@ -2205,7 +2205,7 @@ class DocumentService:
retrieval_model = knowledge_config.retrieval_model
else:
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
search_method=RetrievalMethod.SEMANTIC_SEARCH,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=4,

View File

@@ -3,6 +3,8 @@ from typing import Literal
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
@@ -95,7 +97,7 @@ class WeightModel(BaseModel):
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
search_method: RetrievalMethod
reranking_enable: bool
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None

View File

@@ -2,6 +2,8 @@ from typing import Literal
from pydantic import BaseModel, field_validator
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class IconInfo(BaseModel):
icon: str
@@ -83,7 +85,7 @@ class RetrievalSetting(BaseModel):
Retrieval Setting.
"""
search_method: Literal["semantic_search", "full_text_search", "keyword_search", "hybrid_search"]
search_method: RetrievalMethod
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
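Both swaps of the `Literal[...]` union for `RetrievalMethod` (here and in `RetrievalModel` above) stay backward compatible, because pydantic coerces plain strings into `StrEnum` members. A minimal sketch, assuming `RetrievalMethod` is a `StrEnum` whose member values match the former literals (names inferred from the usages in this diff):

```python
from enum import StrEnum

from pydantic import BaseModel


class RetrievalMethod(StrEnum):
    SEMANTIC_SEARCH = "semantic_search"
    FULL_TEXT_SEARCH = "full_text_search"
    KEYWORD_SEARCH = "keyword_search"
    HYBRID_SEARCH = "hybrid_search"


class RetrievalSetting(BaseModel):
    search_method: RetrievalMethod


# Plain strings from stored JSON still validate into enum members...
setting = RetrievalSetting.model_validate({"search_method": "semantic_search"})
assert setting.search_method is RetrievalMethod.SEMANTIC_SEARCH
# ...and each member still compares equal to its string value.
assert setting.search_method == "semantic_search"
```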

View File

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, model_validator
@@ -27,7 +27,7 @@ from core.model_runtime.entities.provider_entities import (
from models.provider import ProviderType
class CustomConfigurationStatus(Enum):
class CustomConfigurationStatus(StrEnum):
"""
Enum class for custom configuration status.
"""

View File

@@ -88,9 +88,9 @@ class ExternalDatasetService:
else:
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception:
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
response = ssrf_proxy.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
if response.status_code == 404:

View File

@@ -63,7 +63,7 @@ class HitTestingService:
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 4),
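Wrapping the stored value in `RetrievalMethod(...)` is safe whether the dict holds a plain string or the enum default, since calling an enum with one of its own members returns that member unchanged (a sketch, using the `RetrievalMethod` assumed above):

```python
assert RetrievalMethod("semantic_search") is RetrievalMethod.SEMANTIC_SEARCH
assert RetrievalMethod(RetrievalMethod.SEMANTIC_SEARCH) is RetrievalMethod.SEMANTIC_SEARCH
```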

View File

@@ -9,6 +9,7 @@ from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
@@ -164,7 +165,7 @@ class RagPipelineTransformService:
if retrieval_model:
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
retrieval_setting.search_method = RetrievalMethod.KEYWORD_SEARCH
knowledge_configuration.retrieval_model = retrieval_setting
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()

View File

@@ -148,7 +148,7 @@ class ApiToolManageService:
description=extra_info.get("description", ""),
schema_type_str=schema_type,
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
credentials_str={},
credentials_str="{}",
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
)

View File

@@ -683,7 +683,7 @@ class BuiltinToolManageService:
cache=NoOpProviderCredentialCache(),
)
original_params = encrypter.decrypt(custom_client_params.oauth_params)
new_params: dict = {
new_params = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
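The comprehension restores secrets that the client echoed back in masked form: any value still equal to `HIDDEN_VALUE` is replaced with the previously decrypted original. A self-contained sketch with hypothetical sentinel values (the real `HIDDEN_VALUE`/`UNKNOWN_VALUE` constants are defined elsewhere in the codebase):

```python
HIDDEN_VALUE = "[__HIDDEN__]"    # placeholder the UI shows instead of a secret
UNKNOWN_VALUE = "[__UNKNOWN__]"  # fallback when no original value exists

original_params = {"client_secret": "s3cret", "client_id": "abc"}
client_params = {"client_secret": HIDDEN_VALUE, "client_id": "abc"}

new_params = {
    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
    for key, value in client_params.items()
}
assert new_params == {"client_secret": "s3cret", "client_id": "abc"}
```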

View File

@@ -188,6 +188,8 @@ class MCPToolManageService:
raise
user = mcp_provider.load_user()
if not mcp_provider.icon:
raise ValueError("MCP provider icon is required")
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,

View File

@@ -152,7 +152,8 @@ class ToolTransformService:
if decrypt_credentials:
credentials = db_provider.credentials
if not db_provider.tenant_id:
raise ValueError(f"Required tenant_id is missing for BuiltinToolProvider with id {db_provider.id}")
# init tool configuration
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,

View File

@@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
result = protected_view()
# Assert
@@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with pytest.raises(AccountNotInitializedError):
protected_view()
@@ -163,7 +163,7 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@@ -185,7 +185,7 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@@ -207,7 +207,7 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@@ -215,7 +215,7 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@@ -239,7 +239,7 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@@ -271,7 +271,7 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):

View File

@@ -1,10 +1,12 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"

View File

@@ -1,5 +1,7 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
user_id = "user1"
@@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
return text.strip()
def test_notion_page(mocker):
def test_notion_page(mocker: MockerFixture):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
@@ -77,7 +79,7 @@ def test_notion_page(mocker):
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
@@ -39,7 +40,7 @@ def lb_model_manager():
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())

View File

@@ -14,7 +14,13 @@ from core.entities.provider_entities import (
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormOption,
FormType,
ProviderEntity,
)
from models.provider import Provider, ProviderType
@@ -306,3 +312,174 @@ class TestProviderConfiguration:
# Assert
assert credentials == {"openai_api_key": "test_key"}
def test_extract_secret_variables_with_secret_input(self, provider_configuration):
"""Test extracting secret variables from credential form schemas"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 2
assert "api_key" in secret_variables
assert "secret_token" in secret_variables
assert "model_name" not in secret_variables
def test_extract_secret_variables_no_secret_input(self, provider_configuration):
"""Test extracting secret variables when no secret input fields exist"""
# Arrange
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.SELECT,
required=True,
options=[FormOption(label=I18nObject(en_US="0.1", zh_Hans="0.1"), value="0.1")],
),
]
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
def test_extract_secret_variables_empty_list(self, provider_configuration):
"""Test extracting secret variables from empty credential form schemas"""
# Arrange
credential_form_schemas = []
# Act
secret_variables = provider_configuration.extract_secret_variables(credential_form_schemas)
# Assert
assert len(secret_variables) == 0
@patch("core.entities.provider_configuration.encrypter")
def test_obfuscated_credentials_with_secret_variables(self, mock_encrypter, provider_configuration):
"""Test obfuscating credentials with secret variables"""
# Arrange
credentials = {
"api_key": "sk-1234567890abcdef",
"model_name": "gpt-4",
"secret_token": "secret_value_123",
"temperature": "0.7",
}
credential_form_schemas = [
CredentialFormSchema(
variable="api_key",
label=I18nObject(en_US="API Key", zh_Hans="API 密钥"),
type=FormType.SECRET_INPUT,
required=True,
),
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="secret_token",
label=I18nObject(en_US="Secret Token", zh_Hans="密钥令牌"),
type=FormType.SECRET_INPUT,
required=False,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
]
mock_encrypter.obfuscated_token.side_effect = lambda x: f"***{x[-4:]}"
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated["api_key"] == "***cdef"
assert obfuscated["model_name"] == "gpt-4" # Not obfuscated
assert obfuscated["secret_token"] == "***_123"
assert obfuscated["temperature"] == "0.7" # Not obfuscated
# Verify encrypter was called for secret fields only
assert mock_encrypter.obfuscated_token.call_count == 2
mock_encrypter.obfuscated_token.assert_any_call("sk-1234567890abcdef")
mock_encrypter.obfuscated_token.assert_any_call("secret_value_123")
def test_obfuscated_credentials_no_secret_variables(self, provider_configuration):
"""Test obfuscating credentials when no secret variables exist"""
# Arrange
credentials = {
"model_name": "gpt-4",
"temperature": "0.7",
"max_tokens": "1000",
}
credential_form_schemas = [
CredentialFormSchema(
variable="model_name",
label=I18nObject(en_US="Model Name", zh_Hans="模型名称"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=FormType.TEXT_INPUT,
required=True,
),
CredentialFormSchema(
variable="max_tokens",
label=I18nObject(en_US="Max Tokens", zh_Hans="最大令牌数"),
type=FormType.TEXT_INPUT,
required=True,
),
]
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == credentials # No changes expected
def test_obfuscated_credentials_empty_credentials(self, provider_configuration):
"""Test obfuscating empty credentials"""
# Arrange
credentials = {}
credential_form_schemas = []
# Act
obfuscated = provider_configuration.obfuscated_credentials(credentials, credential_form_schemas)
# Assert
assert obfuscated == {}

View File

@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
@@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker):
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
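The `PropertyMock` detour matters because an attribute on a bare `Mock` is itself a `Mock`, which cannot be iterated; a `PropertyMock` returns the real empty list on attribute access. A quick sketch:

```python
from unittest import mock

plain = mock.Mock()
try:
    list(plain.credential_form_schemas)  # the child Mock has no __iter__
except TypeError:
    pass

typed = mock.Mock()
type(typed).credential_form_schemas = mock.PropertyMock(return_value=[])
assert list(typed.credential_form_schemas) == []  # iterates the real list
```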
def test__to_model_settings(mocker, mock_provider_entity):
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(

View File

@@ -1520,7 +1520,7 @@ requires-dist = [
{ name = "opentelemetry-sdk", specifier = "==1.27.0" },
{ name = "opentelemetry-semantic-conventions", specifier = "==0.48b0" },
{ name = "opentelemetry-util-http", specifier = "==0.48b0" },
{ name = "opik", specifier = "~=1.7.25" },
{ name = "opik", specifier = "~=1.8.72" },
{ name = "packaging", specifier = "~=23.2" },
{ name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" },
{ name = "psycogreen", specifier = "~=1.0.2" },
@@ -4019,7 +4019,7 @@ wheels = [
[[package]]
name = "opik"
version = "1.7.43"
version = "1.8.72"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "boto3-stubs", extra = ["bedrock-runtime"] },
@@ -4038,9 +4038,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "uuid6" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/52/cea0317bc3207bc967b48932781995d9cdb2c490e7e05caa00ff660f7205/opik-1.7.43.tar.gz", hash = "sha256:0b02522b0b74d0a67b141939deda01f8bb69690eda6b04a7cecb1c7f0649ccd0", size = 326886, upload-time = "2025-07-07T10:30:07.715Z" }
sdist = { url = "https://files.pythonhosted.org/packages/aa/08/679b60db21994cf3318d4cdd1d08417c1877b79ac20971a8d80f118c9455/opik-1.8.72.tar.gz", hash = "sha256:26fcb003dc609d96b52eaf6a12fb16eb2b69eb0d1b35d88279ec612925d23944", size = 409774, upload-time = "2025-10-10T13:22:38.2Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/ae/f3566bdc3c49a1a8f795b1b6e726ef211c87e31f92d870ca6d63999c9bbf/opik-1.7.43-py3-none-any.whl", hash = "sha256:a66395c8b5ea7c24846f72dafc70c74d5b8f24ffbc4c8a1b3a7f9456e550568d", size = 625356, upload-time = "2025-07-07T10:30:06.389Z" },
{ url = "https://files.pythonhosted.org/packages/f8/f5/04d35af828d127de65a36286ce5b53e7310087a6b55a56f398daa7f0c9a6/opik-1.8.72-py3-none-any.whl", hash = "sha256:697e361a8364666f36aeb197aaba7ffa0696b49f04d2257b733d436749c90a8c", size = 768233, upload-time = "2025-10-10T13:22:36.352Z" },
]
[[package]]

View File

@@ -1 +1,3 @@
recursive-include dify_client *.py
include README.md
include LICENSE

View File

@@ -10,6 +10,8 @@ First, install `dify-client` python sdk package:
pip install dify-client
```
### Synchronous Usage
Write your code with the SDK:
- completion generate with `blocking` response_mode
@@ -221,3 +223,187 @@ answer = result.get("data").get("outputs")
print(answer["answer"])
```
- Dataset Management
```python
from dify_client import KnowledgeBaseClient
api_key = "your_api_key"
dataset_id = "your_dataset_id"
# Use context manager to ensure proper resource cleanup
with KnowledgeBaseClient(api_key, dataset_id) as kb_client:
# Get dataset information
dataset_info = kb_client.get_dataset()
dataset_info.raise_for_status()
print(dataset_info.json())
# Update dataset configuration
update_response = kb_client.update_dataset(
name="Updated Dataset Name",
description="Updated description",
indexing_technique="high_quality"
)
update_response.raise_for_status()
print(update_response.json())
# Batch update document status
batch_response = kb_client.batch_update_document_status(
action="enable",
document_ids=["doc_id_1", "doc_id_2", "doc_id_3"]
)
batch_response.raise_for_status()
print(batch_response.json())
```
- Conversation Variables Management
```python
from dify_client import ChatClient
api_key = "your_api_key"
# Use context manager to ensure proper resource cleanup
with ChatClient(api_key) as chat_client:
# Get all conversation variables
variables = chat_client.get_conversation_variables(
conversation_id="conversation_id",
user="user_id"
)
variables.raise_for_status()
print(variables.json())
# Update a specific conversation variable
update_var = chat_client.update_conversation_variable(
conversation_id="conversation_id",
variable_id="variable_id",
value="new_value",
user="user_id"
)
update_var.raise_for_status()
print(update_var.json())
```
### Asynchronous Usage
The SDK provides full async/await support for all API operations using `httpx.AsyncClient`. All async clients mirror their synchronous counterparts but require `await` for method calls.
- async chat with `blocking` response_mode
```python
import asyncio
from dify_client import AsyncChatClient
api_key = "your_api_key"
async def main():
# Use async context manager for proper resource cleanup
async with AsyncChatClient(api_key) as client:
response = await client.create_chat_message(
inputs={},
query="Hello, how are you?",
user="user_id",
response_mode="blocking"
)
response.raise_for_status()
result = response.json()
print(result.get('answer'))
# Run the async function
asyncio.run(main())
```
- async completion with `streaming` response_mode
```python
import asyncio
import json
from dify_client import AsyncCompletionClient
api_key = "your_api_key"
async def main():
async with AsyncCompletionClient(api_key) as client:
response = await client.create_completion_message(
inputs={"query": "What's the weather?"},
response_mode="streaming",
user="user_id"
)
response.raise_for_status()
# Stream the response
async for line in response.aiter_lines():
if line.startswith('data:'):
data = line[5:].strip()
if data:
chunk = json.loads(data)
print(chunk.get('answer', ''), end='', flush=True)
asyncio.run(main())
```
- async workflow execution
```python
import asyncio
from dify_client import AsyncWorkflowClient
api_key = "your_api_key"
async def main():
async with AsyncWorkflowClient(api_key) as client:
response = await client.run(
inputs={"query": "What is machine learning?"},
response_mode="blocking",
user="user_id"
)
response.raise_for_status()
result = response.json()
print(result.get("data").get("outputs"))
asyncio.run(main())
```
- async dataset management
```python
import asyncio
from dify_client import AsyncKnowledgeBaseClient
api_key = "your_api_key"
dataset_id = "your_dataset_id"
async def main():
async with AsyncKnowledgeBaseClient(api_key, dataset_id) as kb_client:
# Get dataset information
dataset_info = await kb_client.get_dataset()
dataset_info.raise_for_status()
print(dataset_info.json())
# List documents
docs = await kb_client.list_documents(page=1, page_size=10)
docs.raise_for_status()
print(docs.json())
asyncio.run(main())
```
**Benefits of Async Usage:**
- **Better Performance**: Handle multiple concurrent API requests efficiently
- **Non-blocking I/O**: Don't block the event loop during network operations
- **Scalability**: Ideal for applications handling many simultaneous requests
- **Modern Python**: Leverages Python's native async/await syntax
**Available Async Clients:**
- `AsyncDifyClient` - Base async client
- `AsyncChatClient` - Async chat operations
- `AsyncCompletionClient` - Async completion operations
- `AsyncWorkflowClient` - Async workflow operations
- `AsyncKnowledgeBaseClient` - Async dataset/knowledge base operations
- `AsyncWorkspaceClient` - Async workspace operations
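For example, several requests can be issued concurrently over a single connection pool with `asyncio.gather` (a sketch; the queries and user ids are placeholders):

```python
import asyncio

from dify_client import AsyncChatClient

api_key = "your_api_key"

async def main():
    async with AsyncChatClient(api_key) as client:
        # Fire two chat requests concurrently instead of one after the other.
        responses = await asyncio.gather(
            client.create_chat_message(inputs={}, query="Hello!", user="user-1"),
            client.create_chat_message(inputs={}, query="Hi there!", user="user-2"),
        )
        for response in responses:
            response.raise_for_status()
            print(response.json().get("answer"))

asyncio.run(main())
```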

View File

@@ -7,11 +7,28 @@ from dify_client.client import (
WorkspaceClient,
)
from dify_client.async_client import (
AsyncChatClient,
AsyncCompletionClient,
AsyncDifyClient,
AsyncKnowledgeBaseClient,
AsyncWorkflowClient,
AsyncWorkspaceClient,
)
__all__ = [
# Synchronous clients
"ChatClient",
"CompletionClient",
"DifyClient",
"KnowledgeBaseClient",
"WorkflowClient",
"WorkspaceClient",
# Asynchronous clients
"AsyncChatClient",
"AsyncCompletionClient",
"AsyncDifyClient",
"AsyncKnowledgeBaseClient",
"AsyncWorkflowClient",
"AsyncWorkspaceClient",
]

View File

@@ -0,0 +1,808 @@
"""Asynchronous Dify API client.
This module provides async/await support for all Dify API operations using httpx.AsyncClient.
All client classes mirror their synchronous counterparts but require `await` for method calls.
Example:
import asyncio
from dify_client import AsyncChatClient
async def main():
async with AsyncChatClient(api_key="your-key") as client:
response = await client.create_chat_message(
inputs={},
query="Hello",
user="user-123"
)
print(response.json())
asyncio.run(main())
"""
import json
import os
from typing import IO, Any, Literal
import aiofiles
import httpx
class AsyncDifyClient:
"""Asynchronous Dify API client.
This client uses httpx.AsyncClient for efficient async connection pooling.
It's recommended to use this client as a context manager:
Example:
async with AsyncDifyClient(api_key="your-key") as client:
response = await client.get_app_info()
"""
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
timeout: float = 60.0,
):
"""Initialize the async Dify client.
Args:
api_key: Your Dify API key
base_url: Base URL for the Dify API
timeout: Request timeout in seconds (default: 60.0)
"""
self.api_key = api_key
self.base_url = base_url
self._client = httpx.AsyncClient(
base_url=base_url,
timeout=httpx.Timeout(timeout, connect=5.0),
)
async def __aenter__(self):
"""Support async context manager protocol."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Clean up resources when exiting async context."""
await self.aclose()
async def aclose(self):
"""Close the async HTTP client and release resources."""
if hasattr(self, "_client"):
await self._client.aclose()
async def _send_request(
self,
method: str,
endpoint: str,
json: dict | None = None,
params: dict | None = None,
stream: bool = False,
**kwargs,
):
"""Send an async HTTP request to the Dify API.
Args:
method: HTTP method (GET, POST, PUT, PATCH, DELETE)
endpoint: API endpoint path
json: JSON request body
params: Query parameters
stream: Whether to stream the response
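Note: accepted for signature parity; the request below is sent without streaming and the body is read eagerly.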
**kwargs: Additional arguments to pass to httpx.request
Returns:
httpx.Response object
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
response = await self._client.request(
method,
endpoint,
json=json,
params=params,
headers=headers,
**kwargs,
)
return response
async def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict):
"""Send an async HTTP request with file uploads.
Args:
method: HTTP method (POST, PUT, etc.)
endpoint: API endpoint path
data: Form data
files: Files to upload
Returns:
httpx.Response object
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
response = await self._client.request(
method,
endpoint,
data=data,
headers=headers,
files=files,
)
return response
async def message_feedback(self, message_id: str, rating: Literal["like", "dislike"], user: str):
"""Send feedback for a message."""
data = {"rating": rating, "user": user}
return await self._send_request("POST", f"/messages/{message_id}/feedbacks", data)
async def get_application_parameters(self, user: str):
"""Get application parameters."""
params = {"user": user}
return await self._send_request("GET", "/parameters", params=params)
async def file_upload(self, user: str, files: dict):
"""Upload a file."""
data = {"user": user}
return await self._send_request_with_files("POST", "/files/upload", data=data, files=files)
async def text_to_audio(self, text: str, user: str, streaming: bool = False):
"""Convert text to audio."""
data = {"text": text, "user": user, "streaming": streaming}
return await self._send_request("POST", "/text-to-audio", json=data)
async def get_meta(self, user: str):
"""Get metadata."""
params = {"user": user}
return await self._send_request("GET", "/meta", params=params)
async def get_app_info(self):
"""Get basic application information including name, description, tags, and mode."""
return await self._send_request("GET", "/info")
async def get_app_site_info(self):
"""Get application site information."""
return await self._send_request("GET", "/site")
async def get_file_preview(self, file_id: str):
"""Get file preview by file ID."""
return await self._send_request("GET", f"/files/{file_id}/preview")
class AsyncCompletionClient(AsyncDifyClient):
"""Async client for Completion API operations."""
async def create_completion_message(
self,
inputs: dict,
response_mode: Literal["blocking", "streaming"],
user: str,
files: dict | None = None,
):
"""Create a completion message.
Args:
inputs: Input variables for the completion
response_mode: Response mode ('blocking' or 'streaming')
user: User identifier
files: Optional files to include
Returns:
httpx.Response object
"""
data = {
"inputs": inputs,
"response_mode": response_mode,
"user": user,
"files": files,
}
return await self._send_request(
"POST",
"/completion-messages",
data,
stream=(response_mode == "streaming"),
)
class AsyncChatClient(AsyncDifyClient):
"""Async client for Chat API operations."""
async def create_chat_message(
self,
inputs: dict,
query: str,
user: str,
response_mode: Literal["blocking", "streaming"] = "blocking",
conversation_id: str | None = None,
files: dict | None = None,
):
"""Create a chat message.
Args:
inputs: Input variables for the chat
query: User query/message
user: User identifier
response_mode: Response mode ('blocking' or 'streaming')
conversation_id: Optional conversation ID for context
files: Optional files to include
Returns:
httpx.Response object
"""
data = {
"inputs": inputs,
"query": query,
"user": user,
"response_mode": response_mode,
"files": files,
}
if conversation_id:
data["conversation_id"] = conversation_id
return await self._send_request(
"POST",
"/chat-messages",
data,
stream=(response_mode == "streaming"),
)
async def get_suggested(self, message_id: str, user: str):
"""Get suggested questions for a message."""
params = {"user": user}
return await self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
async def stop_message(self, task_id: str, user: str):
"""Stop a running message generation."""
data = {"user": user}
return await self._send_request("POST", f"/chat-messages/{task_id}/stop", data)
async def get_conversations(
self,
user: str,
last_id: str | None = None,
limit: int | None = None,
pinned: bool | None = None,
):
"""Get list of conversations."""
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
return await self._send_request("GET", "/conversations", params=params)
async def get_conversation_messages(
self,
user: str,
conversation_id: str | None = None,
first_id: str | None = None,
limit: int | None = None,
):
"""Get messages from a conversation."""
params = {
"user": user,
"conversation_id": conversation_id,
"first_id": first_id,
"limit": limit,
}
return await self._send_request("GET", "/messages", params=params)
async def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
"""Rename a conversation."""
data = {"name": name, "auto_generate": auto_generate, "user": user}
return await self._send_request("POST", f"/conversations/{conversation_id}/name", data)
async def delete_conversation(self, conversation_id: str, user: str):
"""Delete a conversation."""
data = {"user": user}
return await self._send_request("DELETE", f"/conversations/{conversation_id}", data)
async def audio_to_text(self, audio_file: IO[bytes] | tuple, user: str):
"""Convert audio to text."""
data = {"user": user}
files = {"file": audio_file}
return await self._send_request_with_files("POST", "/audio-to-text", data, files)
# Annotation APIs
async def annotation_reply_action(
self,
action: Literal["enable", "disable"],
score_threshold: float,
embedding_provider_name: str,
embedding_model_name: str,
):
"""Enable or disable annotation reply feature."""
data = {
"score_threshold": score_threshold,
"embedding_provider_name": embedding_provider_name,
"embedding_model_name": embedding_model_name,
}
return await self._send_request("POST", f"/apps/annotation-reply/{action}", json=data)
async def get_annotation_reply_status(self, action: Literal["enable", "disable"], job_id: str):
"""Get the status of an annotation reply action job."""
return await self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}")
async def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None):
"""List annotations for the application."""
params = {"page": page, "limit": limit, "keyword": keyword}
return await self._send_request("GET", "/apps/annotations", params=params)
async def create_annotation(self, question: str, answer: str):
"""Create a new annotation."""
data = {"question": question, "answer": answer}
return await self._send_request("POST", "/apps/annotations", json=data)
async def update_annotation(self, annotation_id: str, question: str, answer: str):
"""Update an existing annotation."""
data = {"question": question, "answer": answer}
return await self._send_request("PUT", f"/apps/annotations/{annotation_id}", json=data)
async def delete_annotation(self, annotation_id: str):
"""Delete an annotation."""
return await self._send_request("DELETE", f"/apps/annotations/{annotation_id}")
# Conversation Variables APIs
async def get_conversation_variables(self, conversation_id: str, user: str):
"""Get all variables for a specific conversation.
Args:
conversation_id: The conversation ID to query variables for
user: User identifier
Returns:
Response from the API containing:
- variables: List of conversation variables with their values
- conversation_id: The conversation ID
"""
params = {"user": user}
url = f"/conversations/{conversation_id}/variables"
return await self._send_request("GET", url, params=params)
async def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str):
"""Update a specific conversation variable.
Args:
conversation_id: The conversation ID
variable_id: The variable ID to update
value: New value for the variable
user: User identifier
Returns:
Response from the API with updated variable information
"""
data = {"value": value, "user": user}
url = f"/conversations/{conversation_id}/variables/{variable_id}"
return await self._send_request("PATCH", url, json=data)
class AsyncWorkflowClient(AsyncDifyClient):
"""Async client for Workflow API operations."""
async def run(
self,
inputs: dict,
response_mode: Literal["blocking", "streaming"] = "streaming",
user: str = "abc-123",
):
"""Run a workflow."""
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return await self._send_request("POST", "/workflows/run", data)
async def stop(self, task_id: str, user: str):
"""Stop a running workflow task."""
data = {"user": user}
return await self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
async def get_result(self, workflow_run_id: str):
"""Get workflow run result."""
return await self._send_request("GET", f"/workflows/run/{workflow_run_id}")
async def get_workflow_logs(
self,
keyword: str | None = None,
status: Literal["succeeded", "failed", "stopped"] | None = None,
page: int = 1,
limit: int = 20,
created_at__before: str | None = None,
created_at__after: str | None = None,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
):
"""Get workflow execution logs with optional filtering."""
params = {
"page": page,
"limit": limit,
"keyword": keyword,
"status": status,
"created_at__before": created_at__before,
"created_at__after": created_at__after,
"created_by_end_user_session_id": created_by_end_user_session_id,
"created_by_account": created_by_account,
}
return await self._send_request("GET", "/workflows/logs", params=params)
async def run_specific_workflow(
self,
workflow_id: str,
inputs: dict,
response_mode: Literal["blocking", "streaming"] = "streaming",
user: str = "abc-123",
):
"""Run a specific workflow by workflow ID."""
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return await self._send_request(
"POST",
f"/workflows/{workflow_id}/run",
data,
stream=(response_mode == "streaming"),
)
class AsyncWorkspaceClient(AsyncDifyClient):
"""Async client for workspace-related operations."""
async def get_available_models(self, model_type: str):
"""Get available models by model type."""
url = f"/workspaces/current/models/model-types/{model_type}"
return await self._send_request("GET", url)
class AsyncKnowledgeBaseClient(AsyncDifyClient):
"""Async client for Knowledge Base API operations."""
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
dataset_id: str | None = None,
timeout: float = 60.0,
):
"""Construct an AsyncKnowledgeBaseClient object.
Args:
api_key: API key of Dify
base_url: Base URL of Dify API
dataset_id: ID of the dataset
timeout: Request timeout in seconds
"""
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
self.dataset_id = dataset_id
def _get_dataset_id(self):
"""Get the dataset ID, raise error if not set."""
if self.dataset_id is None:
raise ValueError("dataset_id is not set")
return self.dataset_id
async def create_dataset(self, name: str, **kwargs):
"""Create a new dataset."""
return await self._send_request("POST", "/datasets", {"name": name}, **kwargs)
async def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
"""List all datasets."""
return await self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
async def create_document_by_text(self, name: str, text: str, extra_params: dict | None = None, **kwargs):
"""Create a document by text.
Args:
name: Name of the document
text: Text content of the document
extra_params: Extra parameters for the API
Returns:
Response from the API
"""
data = {
"indexing_technique": "high_quality",
"process_rule": {"mode": "automatic"},
"name": name,
"text": text,
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
return await self._send_request("POST", url, json=data, **kwargs)
async def update_document_by_text(
self,
document_id: str,
name: str,
text: str,
extra_params: dict | None = None,
**kwargs,
):
"""Update a document by text."""
data = {"name": name, "text": text}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
return await self._send_request("POST", url, json=data, **kwargs)
async def create_document_by_file(
self,
file_path: str,
original_document_id: str | None = None,
extra_params: dict | None = None,
):
"""Create a document by file."""
async with aiofiles.open(file_path, "rb") as f:
files = {"file": (os.path.basename(file_path), f)}
data = {
"process_rule": {"mode": "automatic"},
"indexing_technique": "high_quality",
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
async def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
"""Update a document by file."""
async with aiofiles.open(file_path, "rb") as f:
files = {"file": (os.path.basename(file_path), f)}
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return await self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
async def batch_indexing_status(self, batch_id: str, **kwargs):
"""Get the status of the batch indexing."""
url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
return await self._send_request("GET", url, **kwargs)
async def delete_dataset(self):
"""Delete this dataset."""
url = f"/datasets/{self._get_dataset_id()}"
return await self._send_request("DELETE", url)
async def delete_document(self, document_id: str):
"""Delete a document."""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
return await self._send_request("DELETE", url)
async def list_documents(
self,
page: int | None = None,
page_size: int | None = None,
keyword: str | None = None,
**kwargs,
):
"""Get a list of documents in this dataset."""
params = {
"page": page,
"limit": page_size,
"keyword": keyword,
}
url = f"/datasets/{self._get_dataset_id()}/documents"
return await self._send_request("GET", url, params=params, **kwargs)
async def add_segments(self, document_id: str, segments: list[dict], **kwargs):
"""Add segments to a document."""
data = {"segments": segments}
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
return await self._send_request("POST", url, json=data, **kwargs)
async def query_segments(
self,
document_id: str,
keyword: str | None = None,
status: str | None = None,
**kwargs,
):
"""Query segments in this document.
Args:
document_id: ID of the document
keyword: Query keyword (optional)
status: Status of the segment (optional, e.g., 'completed')
**kwargs: Additional parameters to pass to the API.
Can include a 'params' dict for extra query parameters.
Returns:
Response from the API
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
params = {
"keyword": keyword,
"status": status,
}
if "params" in kwargs:
params.update(kwargs.pop("params"))
return await self._send_request("GET", url, params=params, **kwargs)
async def delete_document_segment(self, document_id: str, segment_id: str):
"""Delete a segment from a document."""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
return await self._send_request("DELETE", url)
async def update_document_segment(self, document_id: str, segment_id: str, segment_data: dict, **kwargs):
"""Update a segment in a document."""
data = {"segment": segment_data}
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
return await self._send_request("POST", url, json=data, **kwargs)
# Advanced Knowledge Base APIs
async def hit_testing(
self,
query: str,
retrieval_model: Dict[str, Any] | None = None,
external_retrieval_model: Dict[str, Any] | None = None,
):
"""Perform hit testing on the dataset."""
data = {"query": query}
if retrieval_model:
data["retrieval_model"] = retrieval_model
if external_retrieval_model:
data["external_retrieval_model"] = external_retrieval_model
url = f"/datasets/{self._get_dataset_id()}/hit-testing"
return await self._send_request("POST", url, json=data)
async def get_dataset_metadata(self):
"""Get dataset metadata."""
url = f"/datasets/{self._get_dataset_id()}/metadata"
return await self._send_request("GET", url)
async def create_dataset_metadata(self, metadata_data: Dict[str, Any]):
"""Create dataset metadata."""
url = f"/datasets/{self._get_dataset_id()}/metadata"
return await self._send_request("POST", url, json=metadata_data)
async def update_dataset_metadata(self, metadata_id: str, metadata_data: Dict[str, Any]):
"""Update dataset metadata."""
url = f"/datasets/{self._get_dataset_id()}/metadata/{metadata_id}"
return await self._send_request("PATCH", url, json=metadata_data)
async def get_built_in_metadata(self):
"""Get built-in metadata."""
url = f"/datasets/{self._get_dataset_id()}/metadata/built-in"
return await self._send_request("GET", url)
async def manage_built_in_metadata(self, action: str, metadata_data: Dict[str, Any] | None = None):
"""Manage built-in metadata with specified action."""
data = metadata_data or {}
url = f"/datasets/{self._get_dataset_id()}/metadata/built-in/{action}"
return await self._send_request("POST", url, json=data)
async def update_documents_metadata(self, operation_data: List[Dict[str, Any]]):
"""Update metadata for multiple documents."""
url = f"/datasets/{self._get_dataset_id()}/documents/metadata"
data = {"operation_data": operation_data}
return await self._send_request("POST", url, json=data)
# Dataset Tags APIs
async def list_dataset_tags(self):
"""List all dataset tags."""
return await self._send_request("GET", "/datasets/tags")
async def bind_dataset_tags(self, tag_ids: List[str]):
"""Bind tags to dataset."""
data = {"tag_ids": tag_ids, "target_id": self._get_dataset_id()}
return await self._send_request("POST", "/datasets/tags/binding", json=data)
async def unbind_dataset_tag(self, tag_id: str):
"""Unbind a single tag from dataset."""
data = {"tag_id": tag_id, "target_id": self._get_dataset_id()}
return await self._send_request("POST", "/datasets/tags/unbinding", json=data)
async def get_dataset_tags(self):
"""Get tags for current dataset."""
url = f"/datasets/{self._get_dataset_id()}/tags"
return await self._send_request("GET", url)
# RAG Pipeline APIs
async def get_datasource_plugins(self, is_published: bool = True):
"""Get datasource plugins for RAG pipeline."""
params = {"is_published": is_published}
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource-plugins"
return await self._send_request("GET", url, params=params)
async def run_datasource_node(
self,
node_id: str,
inputs: Dict[str, Any],
datasource_type: str,
is_published: bool = True,
credential_id: str | None = None,
):
"""Run a datasource node in RAG pipeline."""
data = {
"inputs": inputs,
"datasource_type": datasource_type,
"is_published": is_published,
}
if credential_id:
data["credential_id"] = credential_id
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run"
return await self._send_request("POST", url, json=data, stream=True)
async def run_rag_pipeline(
self,
inputs: Dict[str, Any],
datasource_type: str,
datasource_info_list: List[Dict[str, Any]],
start_node_id: str,
is_published: bool = True,
response_mode: Literal["streaming", "blocking"] = "blocking",
):
"""Run RAG pipeline."""
data = {
"inputs": inputs,
"datasource_type": datasource_type,
"datasource_info_list": datasource_info_list,
"start_node_id": start_node_id,
"is_published": is_published,
"response_mode": response_mode,
}
url = f"/datasets/{self._get_dataset_id()}/pipeline/run"
return await self._send_request("POST", url, json=data, stream=response_mode == "streaming")
async def upload_pipeline_file(self, file_path: str):
"""Upload file for RAG pipeline."""
async with aiofiles.open(file_path, "rb") as f:
files = {"file": (os.path.basename(file_path), f)}
return await self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files)
# Dataset Management APIs
async def get_dataset(self, dataset_id: str | None = None):
"""Get detailed information about a specific dataset."""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}"
return await self._send_request("GET", url)
async def update_dataset(
self,
dataset_id: str | None = None,
name: str | None = None,
description: str | None = None,
indexing_technique: str | None = None,
embedding_model: str | None = None,
embedding_model_provider: str | None = None,
retrieval_model: Dict[str, Any] | None = None,
**kwargs,
):
"""Update dataset configuration.
Args:
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
name: New dataset name
description: New dataset description
indexing_technique: Indexing technique ('high_quality' or 'economy')
embedding_model: Embedding model name
embedding_model_provider: Embedding model provider
retrieval_model: Retrieval model configuration dict
**kwargs: Additional parameters to pass to the API
Returns:
Response from the API with updated dataset information
"""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}"
payload = {
"name": name,
"description": description,
"indexing_technique": indexing_technique,
"embedding_model": embedding_model,
"embedding_model_provider": embedding_model_provider,
"retrieval_model": retrieval_model,
}
data = {k: v for k, v in payload.items() if v is not None}
data.update(kwargs)
return await self._send_request("PATCH", url, json=data)
async def batch_update_document_status(
self,
action: Literal["enable", "disable", "archive", "un_archive"],
document_ids: List[str],
dataset_id: str | None = None,
):
"""Batch update document status."""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}/documents/status/{action}"
data = {"document_ids": document_ids}
return await self._send_request("PATCH", url, json=data)

View File

@ -1,32 +1,114 @@
import json
from typing import Literal, Union, Dict, List, Any, Optional, IO
import os
from typing import Literal, Dict, List, Any, IO
import requests
import httpx
class DifyClient:
def __init__(self, api_key, base_url: str = "https://api.dify.ai/v1"):
"""Synchronous Dify API client.
This client uses httpx.Client for efficient connection pooling and resource management.
It's recommended to use this client as a context manager:
Example:
with DifyClient(api_key="your-key") as client:
response = client.get_app_info()
"""
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
timeout: float = 60.0,
):
"""Initialize the Dify client.
Args:
api_key: Your Dify API key
base_url: Base URL for the Dify API
timeout: Request timeout in seconds (default: 60.0)
"""
self.api_key = api_key
self.base_url = base_url
self._client = httpx.Client(
base_url=base_url,
timeout=httpx.Timeout(timeout, connect=5.0),
)
def __enter__(self):
"""Support context manager protocol."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Clean up resources when exiting context."""
self.close()
def close(self):
"""Close the HTTP client and release resources."""
if hasattr(self, "_client"):
self._client.close()
def _send_request(
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
self,
method: str,
endpoint: str,
json: dict | None = None,
params: dict | None = None,
stream: bool = False,
**kwargs,
):
"""Send an HTTP request to the Dify API.
Args:
method: HTTP method (GET, POST, PUT, PATCH, DELETE)
endpoint: API endpoint path
json: JSON request body
params: Query parameters
stream: Whether to stream the response
**kwargs: Additional arguments to pass to httpx.request
Returns:
httpx.Response object
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
# httpx.Client automatically prepends base_url
response = self._client.request(
method,
endpoint,
json=json,
params=params,
headers=headers,
**kwargs,
)
return response
def _send_request_with_files(self, method, endpoint, data, files):
def _send_request_with_files(self, method: str, endpoint: str, data: dict, files: dict):
"""Send an HTTP request with file uploads.
Args:
method: HTTP method (POST, PUT, etc.)
endpoint: API endpoint path
data: Form data
files: Files to upload
Returns:
httpx.Response object
"""
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}{endpoint}"
response = requests.request(method, url, data=data, headers=headers, files=files)
response = self._client.request(
method,
endpoint,
data=data,
headers=headers,
files=files,
)
return response
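Both helpers return httpx.Response objects, so standard httpx error handling applies. A sketch using one of the public wrappers (get_app_info):
from dify_client import DifyClient

with DifyClient("your-api-key") as client:
    resp = client.get_app_info()
    resp.raise_for_status()  # httpx raises httpx.HTTPStatusError on 4xx/5xx
    print(resp.json())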
@ -65,7 +147,11 @@ class DifyClient:
class CompletionClient(DifyClient):
def create_completion_message(
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
self,
inputs: dict,
response_mode: Literal["blocking", "streaming"],
user: str,
files: dict | None = None,
):
data = {
"inputs": inputs,
@ -77,7 +163,7 @@ class CompletionClient(DifyClient):
"POST",
"/completion-messages",
data,
stream=True if response_mode == "streaming" else False,
stream=(response_mode == "streaming"),
)
@ -105,7 +191,7 @@ class ChatClient(DifyClient):
"POST",
"/chat-messages",
data,
stream=True if response_mode == "streaming" else False,
stream=(response_mode == "streaming"),
)
def get_suggested(self, message_id: str, user: str):
@ -166,10 +252,6 @@ class ChatClient(DifyClient):
embedding_model_name: str,
):
"""Enable or disable annotation reply feature."""
# Backend API requires these fields to be non-None values
if score_threshold is None or embedding_provider_name is None or embedding_model_name is None:
raise ValueError("score_threshold, embedding_provider_name, and embedding_model_name cannot be None")
data = {
"score_threshold": score_threshold,
"embedding_provider_name": embedding_provider_name,
@ -181,11 +263,9 @@ class ChatClient(DifyClient):
"""Get the status of an annotation reply action job."""
return self._send_request("GET", f"/apps/annotation-reply/{action}/status/{job_id}")
def list_annotations(self, page: int = 1, limit: int = 20, keyword: str = ""):
def list_annotations(self, page: int = 1, limit: int = 20, keyword: str | None = None):
"""List annotations for the application."""
params = {"page": page, "limit": limit}
if keyword:
params["keyword"] = keyword
params = {"page": page, "limit": limit, "keyword": keyword}
return self._send_request("GET", "/apps/annotations", params=params)
def create_annotation(self, question: str, answer: str):
@ -202,9 +282,47 @@ class ChatClient(DifyClient):
"""Delete an annotation."""
return self._send_request("DELETE", f"/apps/annotations/{annotation_id}")
# Conversation Variables APIs
def get_conversation_variables(self, conversation_id: str, user: str):
"""Get all variables for a specific conversation.
Args:
conversation_id: The conversation ID to query variables for
user: User identifier
Returns:
Response from the API containing:
- variables: List of conversation variables with their values
- conversation_id: The conversation ID
"""
params = {"user": user}
url = f"/conversations/{conversation_id}/variables"
return self._send_request("GET", url, params=params)
def update_conversation_variable(self, conversation_id: str, variable_id: str, value: Any, user: str):
"""Update a specific conversation variable.
Args:
conversation_id: The conversation ID
variable_id: The variable ID to update
value: New value for the variable
user: User identifier
Returns:
Response from the API with updated variable information
"""
data = {"value": value, "user": user}
url = f"/conversations/{conversation_id}/variables/{variable_id}"
return self._send_request("PATCH", url, json=data)
class WorkflowClient(DifyClient):
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
def run(
self,
inputs: dict,
response_mode: Literal["blocking", "streaming"] = "streaming",
user: str = "abc-123",
):
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request("POST", "/workflows/run", data)
@ -252,7 +370,10 @@ class WorkflowClient(DifyClient):
"""Run a specific workflow by workflow ID."""
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request(
"POST", f"/workflows/{workflow_id}/run", data, stream=True if response_mode == "streaming" else False
"POST",
f"/workflows/{workflow_id}/run",
data,
stream=(response_mode == "streaming"),
)
@ -293,7 +414,7 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
return self._send_request("GET", "/datasets", params={"page": page, "limit": page_size}, **kwargs)
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
"""
@ -333,7 +454,12 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", url, json=data, **kwargs)
def update_document_by_text(
self, document_id: str, name: str, text: str, extra_params: dict | None = None, **kwargs
self,
document_id: str,
name: str,
text: str,
extra_params: dict | None = None,
**kwargs,
):
"""
Update a document by text.
@ -368,7 +494,10 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file(
self, file_path: str, original_document_id: str | None = None, extra_params: dict | None = None
self,
file_path: str,
original_document_id: str | None = None,
extra_params: dict | None = None,
):
"""
Create a document by file.
@ -395,17 +524,18 @@ class KnowledgeBaseClient(DifyClient):
}
:return: Response from the API
"""
files = {"file": open(file_path, "rb")}
data = {
"process_rule": {"mode": "automatic"},
"indexing_technique": "high_quality",
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
with open(file_path, "rb") as f:
files = {"file": (os.path.basename(file_path), f)}
data = {
"process_rule": {"mode": "automatic"},
"indexing_technique": "high_quality",
}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
"""
@ -433,12 +563,13 @@ class KnowledgeBaseClient(DifyClient):
}
:return:
"""
files = {"file": open(file_path, "rb")}
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
with open(file_path, "rb") as f:
files = {"file": (os.path.basename(file_path), f)}
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def batch_indexing_status(self, batch_id: str, **kwargs):
"""
@ -516,6 +647,8 @@ class KnowledgeBaseClient(DifyClient):
:param document_id: ID of the document
:param keyword: query keyword, optional
:param status: status of the segment, optional, e.g. completed
:param kwargs: Additional parameters to pass to the API.
Can include a 'params' dict for extra query parameters.
"""
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
params = {}
@ -524,7 +657,7 @@ class KnowledgeBaseClient(DifyClient):
if status is not None:
params["status"] = status
if "params" in kwargs:
params.update(kwargs["params"])
params.update(kwargs.pop("params"))
return self._send_request("GET", url, params=params, **kwargs)
def delete_document_segment(self, document_id: str, segment_id: str):
@ -553,7 +686,10 @@ class KnowledgeBaseClient(DifyClient):
# Advanced Knowledge Base APIs
def hit_testing(
self, query: str, retrieval_model: Dict[str, Any] = None, external_retrieval_model: Dict[str, Any] = None
self,
query: str,
retrieval_model: Dict[str, Any] | None = None,
external_retrieval_model: Dict[str, Any] | None = None,
):
"""Perform hit testing on the dataset."""
data = {"query": query}
@ -632,7 +768,11 @@ class KnowledgeBaseClient(DifyClient):
credential_id: str | None = None,
):
"""Run a datasource node in RAG pipeline."""
data = {"inputs": inputs, "datasource_type": datasource_type, "is_published": is_published}
data = {
"inputs": inputs,
"datasource_type": datasource_type,
"is_published": is_published,
}
if credential_id:
data["credential_id"] = credential_id
url = f"/datasets/{self._get_dataset_id()}/pipeline/datasource/nodes/{node_id}/run"
@ -662,5 +802,94 @@ class KnowledgeBaseClient(DifyClient):
def upload_pipeline_file(self, file_path: str):
"""Upload file for RAG pipeline."""
with open(file_path, "rb") as f:
files = {"file": f}
files = {"file": (os.path.basename(file_path), f)}
return self._send_request_with_files("POST", "/datasets/pipeline/file-upload", {}, files)
# Dataset Management APIs
def get_dataset(self, dataset_id: str | None = None):
"""Get detailed information about a specific dataset.
Args:
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
Returns:
Response from the API containing dataset details including:
- name, description, permission
- indexing_technique, embedding_model, embedding_model_provider
- retrieval_model configuration
- document_count, word_count, app_count
- created_at, updated_at
"""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}"
return self._send_request("GET", url)
def update_dataset(
self,
dataset_id: str | None = None,
name: str | None = None,
description: str | None = None,
indexing_technique: str | None = None,
embedding_model: str | None = None,
embedding_model_provider: str | None = None,
retrieval_model: Dict[str, Any] | None = None,
**kwargs,
):
"""Update dataset configuration.
Args:
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
name: New dataset name
description: New dataset description
indexing_technique: Indexing technique ('high_quality' or 'economy')
embedding_model: Embedding model name
embedding_model_provider: Embedding model provider
retrieval_model: Retrieval model configuration dict
**kwargs: Additional parameters to pass to the API
Returns:
Response from the API with updated dataset information
"""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}"
# Build data dictionary with all possible parameters
payload = {
"name": name,
"description": description,
"indexing_technique": indexing_technique,
"embedding_model": embedding_model,
"embedding_model_provider": embedding_model_provider,
"retrieval_model": retrieval_model,
}
# Filter out None values and merge with additional kwargs
data = {k: v for k, v in payload.items() if v is not None}
data.update(kwargs)
return self._send_request("PATCH", url, json=data)
def batch_update_document_status(
self,
action: Literal["enable", "disable", "archive", "un_archive"],
document_ids: List[str],
dataset_id: str | None = None,
):
"""Batch update document status (enable/disable/archive/unarchive).
Args:
action: Action to perform on documents
- 'enable': Enable documents for retrieval
- 'disable': Disable documents from retrieval
- 'archive': Archive documents
- 'un_archive': Unarchive documents
document_ids: List of document IDs to update
dataset_id: Dataset ID (optional, uses current dataset_id if not provided)
Returns:
Response from the API with operation result
"""
ds_id = dataset_id or self._get_dataset_id()
url = f"/datasets/{ds_id}/documents/status/{action}"
data = {"document_ids": document_ids}
return self._send_request("PATCH", url, json=data)

View File

@ -0,0 +1,43 @@
[project]
name = "dify-client"
version = "0.1.12"
description = "A package for interacting with the Dify Service-API"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"httpx>=0.27.0",
"aiofiles>=23.0.0",
]
authors = [
{name = "Dify", email = "hello@dify.ai"}
]
license = {text = "MIT"}
keywords = ["dify", "nlp", "ai", "language-processing"]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
[project.urls]
Homepage = "https://github.com/langgenius/dify"
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-asyncio>=0.21.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["dify_client"]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
asyncio_mode = "auto"

View File

@ -1,26 +0,0 @@
from setuptools import setup
with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
setup(
name="dify-client",
version="0.1.12",
author="Dify",
author_email="hello@dify.ai",
description="A package for interacting with the Dify Service-API",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/langgenius/dify",
license="MIT",
packages=["dify_client"],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires=">=3.6",
install_requires=["requests"],
keywords="dify nlp ai language-processing",
include_package_data=True,
)

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
Test suite for async client implementation in the Python SDK.
This test validates the async/await functionality using httpx.AsyncClient
and ensures API parity with sync clients.
"""
import unittest
from unittest.mock import Mock, patch, AsyncMock
from dify_client.async_client import (
AsyncDifyClient,
AsyncChatClient,
AsyncCompletionClient,
AsyncWorkflowClient,
AsyncWorkspaceClient,
AsyncKnowledgeBaseClient,
)
class TestAsyncAPIParity(unittest.TestCase):
"""Test that async clients have API parity with sync clients."""
def test_dify_client_api_parity(self):
"""Test AsyncDifyClient has same methods as DifyClient."""
from dify_client import DifyClient
sync_methods = {name for name in dir(DifyClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncDifyClient) if not name.startswith("_")}
# aclose is async-specific, close is sync-specific
sync_methods.discard("close")
async_methods.discard("aclose")
# Verify parity
self.assertEqual(sync_methods, async_methods, "API parity mismatch for DifyClient")
def test_chat_client_api_parity(self):
"""Test AsyncChatClient has same methods as ChatClient."""
from dify_client import ChatClient
sync_methods = {name for name in dir(ChatClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncChatClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for ChatClient")
def test_completion_client_api_parity(self):
"""Test AsyncCompletionClient has same methods as CompletionClient."""
from dify_client import CompletionClient
sync_methods = {name for name in dir(CompletionClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncCompletionClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for CompletionClient")
def test_workflow_client_api_parity(self):
"""Test AsyncWorkflowClient has same methods as WorkflowClient."""
from dify_client import WorkflowClient
sync_methods = {name for name in dir(WorkflowClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncWorkflowClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkflowClient")
def test_workspace_client_api_parity(self):
"""Test AsyncWorkspaceClient has same methods as WorkspaceClient."""
from dify_client import WorkspaceClient
sync_methods = {name for name in dir(WorkspaceClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncWorkspaceClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for WorkspaceClient")
def test_knowledge_base_client_api_parity(self):
"""Test AsyncKnowledgeBaseClient has same methods as KnowledgeBaseClient."""
from dify_client import KnowledgeBaseClient
sync_methods = {name for name in dir(KnowledgeBaseClient) if not name.startswith("_")}
async_methods = {name for name in dir(AsyncKnowledgeBaseClient) if not name.startswith("_")}
sync_methods.discard("close")
async_methods.discard("aclose")
self.assertEqual(sync_methods, async_methods, "API parity mismatch for KnowledgeBaseClient")
class TestAsyncClientMocked(unittest.IsolatedAsyncioTestCase):
"""Test async client with mocked httpx.AsyncClient."""
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_client_initialization(self, mock_httpx_async_client):
"""Test async client initializes with httpx.AsyncClient."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
client = AsyncDifyClient("test-key", "https://api.dify.ai/v1")
# Verify httpx.AsyncClient was called
mock_httpx_async_client.assert_called_once()
self.assertEqual(client.api_key, "test-key")
await client.aclose()
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_context_manager(self, mock_httpx_async_client):
"""Test async context manager works."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncDifyClient("test-key") as client:
self.assertEqual(client.api_key, "test-key")
# Verify aclose was called
mock_client_instance.aclose.assert_called_once()
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_send_request(self, mock_httpx_async_client):
"""Test async _send_request method."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"result": "success"})
mock_response.status_code = 200
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncDifyClient("test-key") as client:
response = await client._send_request("GET", "/test")
# Verify request was called
mock_client_instance.request.assert_called_once()
call_args = mock_client_instance.request.call_args
# Verify parameters
self.assertEqual(call_args[0][0], "GET")
self.assertEqual(call_args[0][1], "/test")
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_chat_client(self, mock_httpx_async_client):
"""Test AsyncChatClient functionality."""
mock_response = AsyncMock()
mock_response.text = '{"answer": "Hello!"}'
mock_response.json = AsyncMock(return_value={"answer": "Hello!"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncChatClient("test-key") as client:
response = await client.create_chat_message({}, "Hi", "user123")
self.assertIn("answer", response.text)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_completion_client(self, mock_httpx_async_client):
"""Test AsyncCompletionClient functionality."""
mock_response = AsyncMock()
mock_response.text = '{"answer": "Response"}'
mock_response.json = AsyncMock(return_value={"answer": "Response"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncCompletionClient("test-key") as client:
response = await client.create_completion_message({"query": "test"}, "blocking", "user123")
self.assertIn("answer", response.text)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_workflow_client(self, mock_httpx_async_client):
"""Test AsyncWorkflowClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"result": "success"})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncWorkflowClient("test-key") as client:
response = await client.run({"input": "test"}, "blocking", "user123")
data = await response.json()
self.assertEqual(data["result"], "success")
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_workspace_client(self, mock_httpx_async_client):
"""Test AsyncWorkspaceClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"data": []})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncWorkspaceClient("test-key") as client:
response = await client.get_available_models("llm")
data = await response.json()
self.assertIn("data", data)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_async_knowledge_base_client(self, mock_httpx_async_client):
"""Test AsyncKnowledgeBaseClient functionality."""
mock_response = AsyncMock()
mock_response.json = AsyncMock(return_value={"data": [], "total": 0})
mock_client_instance = AsyncMock()
mock_client_instance.request = AsyncMock(return_value=mock_response)
mock_httpx_async_client.return_value = mock_client_instance
async with AsyncKnowledgeBaseClient("test-key") as client:
response = await client.list_datasets()
data = await response.json()
self.assertIn("data", data)
@patch("dify_client.async_client.httpx.AsyncClient")
async def test_all_async_client_classes(self, mock_httpx_async_client):
"""Test all async client classes work with httpx.AsyncClient."""
mock_client_instance = AsyncMock()
mock_httpx_async_client.return_value = mock_client_instance
clients = [
AsyncDifyClient("key"),
AsyncChatClient("key"),
AsyncCompletionClient("key"),
AsyncWorkflowClient("key"),
AsyncWorkspaceClient("key"),
AsyncKnowledgeBaseClient("key"),
]
# Verify httpx.AsyncClient was called for each
self.assertEqual(mock_httpx_async_client.call_count, 6)
# Clean up
for client in clients:
await client.aclose()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,331 @@
#!/usr/bin/env python3
"""
Test suite for httpx migration in the Python SDK.
This test validates that the migration from requests to httpx maintains
backward compatibility and proper resource management.
"""
import unittest
from unittest.mock import Mock, patch
from dify_client import (
DifyClient,
ChatClient,
CompletionClient,
WorkflowClient,
WorkspaceClient,
KnowledgeBaseClient,
)
class TestHttpxMigrationMocked(unittest.TestCase):
"""Test cases for httpx migration with mocked requests."""
def setUp(self):
"""Set up test fixtures."""
self.api_key = "test-api-key"
self.base_url = "https://api.dify.ai/v1"
@patch("dify_client.client.httpx.Client")
def test_client_initialization(self, mock_httpx_client):
"""Test that client initializes with httpx.Client."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
# Verify httpx.Client was called with correct parameters
mock_httpx_client.assert_called_once()
call_kwargs = mock_httpx_client.call_args[1]
self.assertEqual(call_kwargs["base_url"], self.base_url)
# Verify client properties
self.assertEqual(client.api_key, self.api_key)
self.assertEqual(client.base_url, self.base_url)
client.close()
@patch("dify_client.client.httpx.Client")
def test_context_manager_support(self, mock_httpx_client):
"""Test that client works as context manager."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
with DifyClient(self.api_key, self.base_url) as client:
self.assertEqual(client.api_key, self.api_key)
# Verify close was called
mock_client_instance.close.assert_called_once()
@patch("dify_client.client.httpx.Client")
def test_manual_close(self, mock_httpx_client):
"""Test manual close() method."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
client.close()
# Verify close was called
mock_client_instance.close.assert_called_once()
@patch("dify_client.client.httpx.Client")
def test_send_request_httpx_compatibility(self, mock_httpx_client):
"""Test _send_request uses httpx.Client.request properly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
response = client._send_request("GET", "/test-endpoint")
# Verify httpx.Client.request was called correctly
mock_client_instance.request.assert_called_once()
call_args = mock_client_instance.request.call_args
# Verify method and endpoint
self.assertEqual(call_args[0][0], "GET")
self.assertEqual(call_args[0][1], "/test-endpoint")
# Verify headers contain authorization
headers = call_args[1]["headers"]
self.assertEqual(headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(headers["Content-Type"], "application/json")
client.close()
@patch("dify_client.client.httpx.Client")
def test_response_compatibility(self, mock_httpx_client):
"""Test httpx.Response is compatible with requests.Response API."""
mock_response = Mock()
mock_response.json.return_value = {"key": "value"}
mock_response.text = '{"key": "value"}'
mock_response.content = b'{"key": "value"}'
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
response = client._send_request("GET", "/test")
# Verify all common response methods work
self.assertEqual(response.json(), {"key": "value"})
self.assertEqual(response.text, '{"key": "value"}')
self.assertEqual(response.content, b'{"key": "value"}')
self.assertEqual(response.status_code, 200)
self.assertEqual(response.headers["Content-Type"], "application/json")
client.close()
@patch("dify_client.client.httpx.Client")
def test_all_client_classes_use_httpx(self, mock_httpx_client):
"""Test that all client classes properly use httpx."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
clients = [
DifyClient(self.api_key, self.base_url),
ChatClient(self.api_key, self.base_url),
CompletionClient(self.api_key, self.base_url),
WorkflowClient(self.api_key, self.base_url),
WorkspaceClient(self.api_key, self.base_url),
KnowledgeBaseClient(self.api_key, self.base_url),
]
# Verify httpx.Client was called for each client
self.assertEqual(mock_httpx_client.call_count, 6)
# Clean up
for client in clients:
client.close()
@patch("dify_client.client.httpx.Client")
def test_json_parameter_handling(self, mock_httpx_client):
"""Test that json parameter is passed correctly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
test_data = {"key": "value", "number": 123}
client._send_request("POST", "/test", json=test_data)
# Verify json parameter was passed
call_args = mock_client_instance.request.call_args
self.assertEqual(call_args[1]["json"], test_data)
client.close()
@patch("dify_client.client.httpx.Client")
def test_params_parameter_handling(self, mock_httpx_client):
"""Test that params parameter is passed correctly."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
client = DifyClient(self.api_key, self.base_url)
test_params = {"page": 1, "limit": 20}
client._send_request("GET", "/test", params=test_params)
# Verify params parameter was passed
call_args = mock_client_instance.request.call_args
self.assertEqual(call_args[1]["params"], test_params)
client.close()
@patch("dify_client.client.httpx.Client")
def test_inheritance_chain(self, mock_httpx_client):
"""Test that inheritance chain is maintained."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
# ChatClient inherits from DifyClient
chat_client = ChatClient(self.api_key, self.base_url)
self.assertIsInstance(chat_client, DifyClient)
# CompletionClient inherits from DifyClient
completion_client = CompletionClient(self.api_key, self.base_url)
self.assertIsInstance(completion_client, DifyClient)
# WorkflowClient inherits from DifyClient
workflow_client = WorkflowClient(self.api_key, self.base_url)
self.assertIsInstance(workflow_client, DifyClient)
# Clean up
chat_client.close()
completion_client.close()
workflow_client.close()
@patch("dify_client.client.httpx.Client")
def test_nested_context_managers(self, mock_httpx_client):
"""Test nested context managers work correctly."""
mock_client_instance = Mock()
mock_httpx_client.return_value = mock_client_instance
with DifyClient(self.api_key, self.base_url) as client1:
with ChatClient(self.api_key, self.base_url) as client2:
self.assertEqual(client1.api_key, self.api_key)
self.assertEqual(client2.api_key, self.api_key)
# Both close methods should have been called
self.assertEqual(mock_client_instance.close.call_count, 2)
class TestChatClientHttpx(unittest.TestCase):
"""Test ChatClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_create_chat_message_httpx(self, mock_httpx_client):
"""Test create_chat_message works with httpx."""
mock_response = Mock()
mock_response.text = '{"answer": "Hello!"}'
mock_response.json.return_value = {"answer": "Hello!"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with ChatClient("test-key") as client:
response = client.create_chat_message({}, "Hi", "user123")
self.assertIn("answer", response.text)
self.assertEqual(response.json()["answer"], "Hello!")
class TestCompletionClientHttpx(unittest.TestCase):
"""Test CompletionClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_create_completion_message_httpx(self, mock_httpx_client):
"""Test create_completion_message works with httpx."""
mock_response = Mock()
mock_response.text = '{"answer": "Response"}'
mock_response.json.return_value = {"answer": "Response"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with CompletionClient("test-key") as client:
response = client.create_completion_message({"query": "test"}, "blocking", "user123")
self.assertIn("answer", response.text)
class TestKnowledgeBaseClientHttpx(unittest.TestCase):
"""Test KnowledgeBaseClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_list_datasets_httpx(self, mock_httpx_client):
"""Test list_datasets works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"data": [], "total": 0}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with KnowledgeBaseClient("test-key") as client:
response = client.list_datasets()
data = response.json()
self.assertIn("data", data)
self.assertIn("total", data)
class TestWorkflowClientHttpx(unittest.TestCase):
"""Test WorkflowClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_run_workflow_httpx(self, mock_httpx_client):
"""Test run workflow works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with WorkflowClient("test-key") as client:
response = client.run({"input": "test"}, "blocking", "user123")
self.assertEqual(response.json()["result"], "success")
class TestWorkspaceClientHttpx(unittest.TestCase):
"""Test WorkspaceClient specific httpx integration."""
@patch("dify_client.client.httpx.Client")
def test_get_available_models_httpx(self, mock_httpx_client):
"""Test get_available_models works with httpx."""
mock_response = Mock()
mock_response.json.return_value = {"data": []}
mock_response.status_code = 200
mock_client_instance = Mock()
mock_client_instance.request.return_value = mock_response
mock_httpx_client.return_value = mock_client_instance
with WorkspaceClient("test-key") as client:
response = client.get_available_models("llm")
self.assertIn("data", response.json())
if __name__ == "__main__":
unittest.main()

View File

@ -1,416 +0,0 @@
#!/usr/bin/env python3
"""
Test suite for the new Service API functionality in the Python SDK.
This test validates the implementation of the missing Service API endpoints
that were added to the Python SDK to achieve complete coverage.
"""
import unittest
from unittest.mock import Mock, patch, MagicMock
import json
from dify_client import (
DifyClient,
ChatClient,
WorkflowClient,
KnowledgeBaseClient,
WorkspaceClient,
)
class TestNewServiceAPIs(unittest.TestCase):
"""Test cases for new Service API implementations."""
def setUp(self):
"""Set up test fixtures."""
self.api_key = "test-api-key"
self.base_url = "https://api.dify.ai/v1"
@patch("dify_client.client.requests.request")
def test_app_info_apis(self, mock_request):
"""Test application info APIs."""
mock_response = Mock()
mock_response.json.return_value = {
"name": "Test App",
"description": "Test Description",
"tags": ["test", "api"],
"mode": "chat",
"author_name": "Test Author",
}
mock_request.return_value = mock_response
client = DifyClient(self.api_key, self.base_url)
# Test get_app_info
result = client.get_app_info()
mock_request.assert_called_with(
"GET",
f"{self.base_url}/info",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test get_app_site_info
client.get_app_site_info()
mock_request.assert_called_with(
"GET",
f"{self.base_url}/site",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test get_file_preview
file_id = "test-file-id"
client.get_file_preview(file_id)
mock_request.assert_called_with(
"GET",
f"{self.base_url}/files/{file_id}/preview",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
@patch("dify_client.client.requests.request")
def test_annotation_apis(self, mock_request):
"""Test annotation APIs."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
client = ChatClient(self.api_key, self.base_url)
# Test annotation_reply_action - enable
client.annotation_reply_action(
action="enable",
score_threshold=0.8,
embedding_provider_name="openai",
embedding_model_name="text-embedding-ada-002",
)
mock_request.assert_called_with(
"POST",
f"{self.base_url}/apps/annotation-reply/enable",
json={
"score_threshold": 0.8,
"embedding_provider_name": "openai",
"embedding_model_name": "text-embedding-ada-002",
},
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test annotation_reply_action - disable (now requires same fields as enable)
client.annotation_reply_action(
action="disable",
score_threshold=0.5,
embedding_provider_name="openai",
embedding_model_name="text-embedding-ada-002",
)
# Test annotation_reply_action with score_threshold=0 (edge case)
client.annotation_reply_action(
action="enable",
score_threshold=0.0, # This should work and not raise ValueError
embedding_provider_name="openai",
embedding_model_name="text-embedding-ada-002",
)
# Test get_annotation_reply_status
client.get_annotation_reply_status("enable", "job-123")
# Test list_annotations
client.list_annotations(page=1, limit=20, keyword="test")
# Test create_annotation
client.create_annotation("Test question?", "Test answer.")
# Test update_annotation
client.update_annotation("annotation-123", "Updated question?", "Updated answer.")
# Test delete_annotation
client.delete_annotation("annotation-123")
# Verify all calls were made (8 calls: enable + disable + enable with 0.0 + 5 other operations)
self.assertEqual(mock_request.call_count, 8)
@patch("dify_client.client.requests.request")
def test_knowledge_base_advanced_apis(self, mock_request):
"""Test advanced knowledge base APIs."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
dataset_id = "test-dataset-id"
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
# Test hit_testing
client.hit_testing("test query", {"type": "vector"})
mock_request.assert_called_with(
"POST",
f"{self.base_url}/datasets/{dataset_id}/hit-testing",
json={"query": "test query", "retrieval_model": {"type": "vector"}},
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test metadata operations
client.get_dataset_metadata()
client.create_dataset_metadata({"key": "value"})
client.update_dataset_metadata("meta-123", {"key": "new_value"})
client.get_built_in_metadata()
client.manage_built_in_metadata("enable", {"type": "built_in"})
client.update_documents_metadata([{"document_id": "doc1", "metadata": {"key": "value"}}])
# Test tag operations
client.list_dataset_tags()
client.bind_dataset_tags(["tag1", "tag2"])
client.unbind_dataset_tag("tag1")
client.get_dataset_tags()
# Verify multiple calls were made
self.assertGreater(mock_request.call_count, 5)
@patch("dify_client.client.requests.request")
def test_rag_pipeline_apis(self, mock_request):
"""Test RAG pipeline APIs."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
dataset_id = "test-dataset-id"
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
# Test get_datasource_plugins
client.get_datasource_plugins(is_published=True)
mock_request.assert_called_with(
"GET",
f"{self.base_url}/datasets/{dataset_id}/pipeline/datasource-plugins",
json=None,
params={"is_published": True},
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test run_datasource_node
client.run_datasource_node(
node_id="node-123",
inputs={"param": "value"},
datasource_type="online_document",
is_published=True,
credential_id="cred-123",
)
# Test run_rag_pipeline with blocking mode
client.run_rag_pipeline(
inputs={"query": "test"},
datasource_type="online_document",
datasource_info_list=[{"id": "ds1"}],
start_node_id="start-node",
is_published=True,
response_mode="blocking",
)
# Test run_rag_pipeline with streaming mode
client.run_rag_pipeline(
inputs={"query": "test"},
datasource_type="online_document",
datasource_info_list=[{"id": "ds1"}],
start_node_id="start-node",
is_published=True,
response_mode="streaming",
)
self.assertEqual(mock_request.call_count, 4)
@patch("dify_client.client.requests.request")
def test_workspace_apis(self, mock_request):
"""Test workspace APIs."""
mock_response = Mock()
mock_response.json.return_value = {
"data": [{"name": "gpt-3.5-turbo", "type": "llm"}, {"name": "gpt-4", "type": "llm"}]
}
mock_request.return_value = mock_response
client = WorkspaceClient(self.api_key, self.base_url)
# Test get_available_models
result = client.get_available_models("llm")
mock_request.assert_called_with(
"GET",
f"{self.base_url}/workspaces/current/models/model-types/llm",
json=None,
params=None,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
@patch("dify_client.client.requests.request")
def test_workflow_advanced_apis(self, mock_request):
"""Test advanced workflow APIs."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
client = WorkflowClient(self.api_key, self.base_url)
# Test get_workflow_logs
client.get_workflow_logs(keyword="test", status="succeeded", page=1, limit=20)
mock_request.assert_called_with(
"GET",
f"{self.base_url}/workflows/logs",
json=None,
params={"page": 1, "limit": 20, "keyword": "test", "status": "succeeded"},
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
stream=False,
)
# Test get_workflow_logs with additional filters
client.get_workflow_logs(
keyword="test",
status="succeeded",
page=1,
limit=20,
created_at__before="2024-01-01",
created_at__after="2023-01-01",
created_by_account="user123",
)
# Test run_specific_workflow
client.run_specific_workflow(
workflow_id="workflow-123", inputs={"param": "value"}, response_mode="streaming", user="user-123"
)
self.assertEqual(mock_request.call_count, 3)
def test_error_handling(self):
"""Test error handling for required parameters."""
client = ChatClient(self.api_key, self.base_url)
# Test annotation_reply_action with missing required parameters would be a TypeError now
# since parameters are required in method signature
with self.assertRaises(TypeError):
client.annotation_reply_action("enable")
# Test annotation_reply_action with explicit None values should raise ValueError
with self.assertRaises(ValueError) as context:
client.annotation_reply_action("enable", None, "provider", "model")
self.assertIn("cannot be None", str(context.exception))
# Test KnowledgeBaseClient without dataset_id
kb_client = KnowledgeBaseClient(self.api_key, self.base_url)
with self.assertRaises(ValueError) as context:
kb_client.hit_testing("test query")
self.assertIn("dataset_id is not set", str(context.exception))
@patch("dify_client.client.open")
@patch("dify_client.client.requests.request")
def test_file_upload_apis(self, mock_request, mock_open):
"""Test file upload APIs."""
mock_response = Mock()
mock_response.json.return_value = {"result": "success"}
mock_request.return_value = mock_response
mock_file = MagicMock()
mock_open.return_value.__enter__.return_value = mock_file
dataset_id = "test-dataset-id"
client = KnowledgeBaseClient(self.api_key, self.base_url, dataset_id)
# Test upload_pipeline_file
client.upload_pipeline_file("/path/to/test.pdf")
mock_open.assert_called_with("/path/to/test.pdf", "rb")
mock_request.assert_called_once()
def test_comprehensive_coverage(self):
"""Test that all previously missing APIs are now implemented."""
# Test DifyClient methods
dify_methods = ["get_app_info", "get_app_site_info", "get_file_preview"]
client = DifyClient(self.api_key)
for method in dify_methods:
self.assertTrue(hasattr(client, method), f"DifyClient missing method: {method}")
# Test ChatClient annotation methods
chat_methods = [
"annotation_reply_action",
"get_annotation_reply_status",
"list_annotations",
"create_annotation",
"update_annotation",
"delete_annotation",
]
chat_client = ChatClient(self.api_key)
for method in chat_methods:
self.assertTrue(hasattr(chat_client, method), f"ChatClient missing method: {method}")
# Test WorkflowClient advanced methods
workflow_methods = ["get_workflow_logs", "run_specific_workflow"]
workflow_client = WorkflowClient(self.api_key)
for method in workflow_methods:
self.assertTrue(hasattr(workflow_client, method), f"WorkflowClient missing method: {method}")
# Test KnowledgeBaseClient advanced methods
kb_methods = [
"hit_testing",
"get_dataset_metadata",
"create_dataset_metadata",
"update_dataset_metadata",
"get_built_in_metadata",
"manage_built_in_metadata",
"update_documents_metadata",
"list_dataset_tags",
"bind_dataset_tags",
"unbind_dataset_tag",
"get_dataset_tags",
"get_datasource_plugins",
"run_datasource_node",
"run_rag_pipeline",
"upload_pipeline_file",
]
kb_client = KnowledgeBaseClient(self.api_key)
for method in kb_methods:
self.assertTrue(hasattr(kb_client, method), f"KnowledgeBaseClient missing method: {method}")
# Test WorkspaceClient methods
workspace_methods = ["get_available_models"]
workspace_client = WorkspaceClient(self.api_key)
for method in workspace_methods:
self.assertTrue(hasattr(workspace_client, method), f"WorkspaceClient missing method: {method}")
if __name__ == "__main__":
unittest.main()

271 sdks/python-client/uv.lock Normal file
View File

@@ -0,0 +1,271 @@
version = 1
revision = 3
requires-python = ">=3.10"
[[package]]
name = "aiofiles"
version = "25.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/41/c3/534eac40372d8ee36ef40df62ec129bee4fdb5ad9706e58a29be53b2c970/aiofiles-25.1.0.tar.gz", hash = "sha256:a8d728f0a29de45dc521f18f07297428d56992a742f0cd2701ba86e44d23d5b2", size = 46354, upload-time = "2025-10-09T20:51:04.358Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl", hash = "sha256:abe311e527c862958650f9438e859c1fa7568a141b22abcd015e120e86a85695", size = 14668, upload-time = "2025-10-09T20:51:03.174Z" },
]
[[package]]
name = "anyio"
version = "4.11.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "sniffio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" },
]
[[package]]
name = "backports-asyncio-runner"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" },
]
[[package]]
name = "certifi"
version = "2025.10.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" },
]
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "dify-client"
version = "0.1.12"
source = { editable = "." }
dependencies = [
{ name = "aiofiles" },
{ name = "httpx" },
]
[package.optional-dependencies]
dev = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
]
[package.metadata]
requires-dist = [
{ name = "aiofiles", specifier = ">=23.0.0" },
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" },
]
provides-extras = ["dev"]
[[package]]
name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" },
]
[[package]]
name = "h11"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
]
[[package]]
name = "httpcore"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "certifi" },
{ name = "httpcore" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
]
[[package]]
name = "idna"
version = "3.10"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "8.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" },
{ name = "pytest" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" },
]
[[package]]
name = "sniffio"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" },
]
[[package]]
name = "tomli"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" },
{ url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" },
{ url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" },
{ url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" },
{ url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" },
{ url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" },
{ url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" },
{ url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" },
{ url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" },
{ url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" },
{ url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" },
{ url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" },
{ url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" },
{ url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" },
{ url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" },
{ url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" },
{ url = "https://files.pythonhosted.org/packages/89/48/06ee6eabe4fdd9ecd48bf488f4ac783844fd777f547b8d1b61c11939974e/tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b", size = 154819, upload-time = "2025-10-08T22:01:17.964Z" },
{ url = "https://files.pythonhosted.org/packages/f1/01/88793757d54d8937015c75dcdfb673c65471945f6be98e6a0410fba167ed/tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae", size = 148766, upload-time = "2025-10-08T22:01:18.959Z" },
{ url = "https://files.pythonhosted.org/packages/42/17/5e2c956f0144b812e7e107f94f1cc54af734eb17b5191c0bbfb72de5e93e/tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b", size = 240771, upload-time = "2025-10-08T22:01:20.106Z" },
{ url = "https://files.pythonhosted.org/packages/d5/f4/0fbd014909748706c01d16824eadb0307115f9562a15cbb012cd9b3512c5/tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf", size = 248586, upload-time = "2025-10-08T22:01:21.164Z" },
{ url = "https://files.pythonhosted.org/packages/30/77/fed85e114bde5e81ecf9bc5da0cc69f2914b38f4708c80ae67d0c10180c5/tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f", size = 244792, upload-time = "2025-10-08T22:01:22.417Z" },
{ url = "https://files.pythonhosted.org/packages/55/92/afed3d497f7c186dc71e6ee6d4fcb0acfa5f7d0a1a2878f8beae379ae0cc/tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05", size = 248909, upload-time = "2025-10-08T22:01:23.859Z" },
{ url = "https://files.pythonhosted.org/packages/f8/84/ef50c51b5a9472e7265ce1ffc7f24cd4023d289e109f669bdb1553f6a7c2/tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606", size = 96946, upload-time = "2025-10-08T22:01:24.893Z" },
{ url = "https://files.pythonhosted.org/packages/b2/b7/718cd1da0884f281f95ccfa3a6cc572d30053cba64603f79d431d3c9b61b/tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999", size = 107705, upload-time = "2025-10-08T22:01:26.153Z" },
{ url = "https://files.pythonhosted.org/packages/19/94/aeafa14a52e16163008060506fcb6aa1949d13548d13752171a755c65611/tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e", size = 154244, upload-time = "2025-10-08T22:01:27.06Z" },
{ url = "https://files.pythonhosted.org/packages/db/e4/1e58409aa78eefa47ccd19779fc6f36787edbe7d4cd330eeeedb33a4515b/tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3", size = 148637, upload-time = "2025-10-08T22:01:28.059Z" },
{ url = "https://files.pythonhosted.org/packages/26/b6/d1eccb62f665e44359226811064596dd6a366ea1f985839c566cd61525ae/tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc", size = 241925, upload-time = "2025-10-08T22:01:29.066Z" },
{ url = "https://files.pythonhosted.org/packages/70/91/7cdab9a03e6d3d2bb11beae108da5bdc1c34bdeb06e21163482544ddcc90/tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0", size = 249045, upload-time = "2025-10-08T22:01:31.98Z" },
{ url = "https://files.pythonhosted.org/packages/15/1b/8c26874ed1f6e4f1fcfeb868db8a794cbe9f227299402db58cfcc858766c/tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879", size = 245835, upload-time = "2025-10-08T22:01:32.989Z" },
{ url = "https://files.pythonhosted.org/packages/fd/42/8e3c6a9a4b1a1360c1a2a39f0b972cef2cc9ebd56025168c4137192a9321/tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005", size = 253109, upload-time = "2025-10-08T22:01:34.052Z" },
{ url = "https://files.pythonhosted.org/packages/22/0c/b4da635000a71b5f80130937eeac12e686eefb376b8dee113b4a582bba42/tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463", size = 97930, upload-time = "2025-10-08T22:01:35.082Z" },
{ url = "https://files.pythonhosted.org/packages/b9/74/cb1abc870a418ae99cd5c9547d6bce30701a954e0e721821df483ef7223c/tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8", size = 107964, upload-time = "2025-10-08T22:01:36.057Z" },
{ url = "https://files.pythonhosted.org/packages/54/78/5c46fff6432a712af9f792944f4fcd7067d8823157949f4e40c56b8b3c83/tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77", size = 163065, upload-time = "2025-10-08T22:01:37.27Z" },
{ url = "https://files.pythonhosted.org/packages/39/67/f85d9bd23182f45eca8939cd2bc7050e1f90c41f4a2ecbbd5963a1d1c486/tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf", size = 159088, upload-time = "2025-10-08T22:01:38.235Z" },
{ url = "https://files.pythonhosted.org/packages/26/5a/4b546a0405b9cc0659b399f12b6adb750757baf04250b148d3c5059fc4eb/tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530", size = 268193, upload-time = "2025-10-08T22:01:39.712Z" },
{ url = "https://files.pythonhosted.org/packages/42/4f/2c12a72ae22cf7b59a7fe75b3465b7aba40ea9145d026ba41cb382075b0e/tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b", size = 275488, upload-time = "2025-10-08T22:01:40.773Z" },
{ url = "https://files.pythonhosted.org/packages/92/04/a038d65dbe160c3aa5a624e93ad98111090f6804027d474ba9c37c8ae186/tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67", size = 272669, upload-time = "2025-10-08T22:01:41.824Z" },
{ url = "https://files.pythonhosted.org/packages/be/2f/8b7c60a9d1612a7cbc39ffcca4f21a73bf368a80fc25bccf8253e2563267/tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f", size = 279709, upload-time = "2025-10-08T22:01:43.177Z" },
{ url = "https://files.pythonhosted.org/packages/7e/46/cc36c679f09f27ded940281c38607716c86cf8ba4a518d524e349c8b4874/tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0", size = 107563, upload-time = "2025-10-08T22:01:44.233Z" },
{ url = "https://files.pythonhosted.org/packages/84/ff/426ca8683cf7b753614480484f6437f568fd2fda2edbdf57a2d3d8b27a0b/tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba", size = 119756, upload-time = "2025-10-08T22:01:45.234Z" },
{ url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
]
[[package]]
name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" },
]

View File

@@ -13,39 +13,60 @@ import { ThemeProvider } from 'next-themes'
import useTheme from '@/hooks/use-theme'
import { useEffect, useState } from 'react'
const DARK_MODE_MEDIA_QUERY = /prefers-color-scheme:\s*dark/i
// Setup browser environment for testing
const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = false) => {
// Mock localStorage
const mockStorage = {
getItem: jest.fn((key: string) => {
if (key === 'theme') return storedTheme
return null
}),
setItem: jest.fn(),
removeItem: jest.fn(),
if (typeof window === 'undefined')
return
try {
window.localStorage.clear()
}
catch {
// ignore if localStorage has been replaced by a throwing stub
}
// Mock system theme preference
const mockMatchMedia = jest.fn((query: string) => ({
matches: query.includes('dark') && systemPrefersDark,
media: query,
addListener: jest.fn(),
removeListener: jest.fn(),
}))
if (storedTheme === null)
window.localStorage.removeItem('theme')
else
window.localStorage.setItem('theme', storedTheme)
if (typeof window !== 'undefined') {
Object.defineProperty(window, 'localStorage', {
value: mockStorage,
configurable: true,
})
document.documentElement.removeAttribute('data-theme')
Object.defineProperty(window, 'matchMedia', {
value: mockMatchMedia,
configurable: true,
})
const mockMatchMedia: typeof window.matchMedia = (query: string) => {
const listeners = new Set<(event: MediaQueryListEvent) => void>()
const isDarkQuery = DARK_MODE_MEDIA_QUERY.test(query)
const matches = isDarkQuery ? systemPrefersDark : false
const mediaQueryList: MediaQueryList = {
matches,
media: query,
onchange: null,
addListener: (listener: MediaQueryListListener) => {
listeners.add(listener)
},
removeListener: (listener: MediaQueryListListener) => {
listeners.delete(listener)
},
addEventListener: (_event, listener: EventListener) => {
if (typeof listener === 'function')
listeners.add(listener as MediaQueryListListener)
},
removeEventListener: (_event, listener: EventListener) => {
if (typeof listener === 'function')
listeners.delete(listener as MediaQueryListListener)
},
dispatchEvent: (event: Event) => {
listeners.forEach(listener => listener(event as MediaQueryListEvent))
return true
},
}
return mediaQueryList
}
return { mockStorage, mockMatchMedia }
jest.spyOn(window, 'matchMedia').mockImplementation(mockMatchMedia)
}
// Simulate real page component based on Dify's actual theme usage
@@ -94,7 +115,17 @@ const TestThemeProvider = ({ children }: { children: React.ReactNode }) => (
describe('Real Browser Environment Dark Mode Flicker Test', () => {
beforeEach(() => {
jest.restoreAllMocks()
jest.clearAllMocks()
if (typeof window !== 'undefined') {
try {
window.localStorage.clear()
}
catch {
// ignore when localStorage is replaced with an error-throwing stub
}
document.documentElement.removeAttribute('data-theme')
}
})
describe('Page Refresh Scenario Simulation', () => {
@@ -323,35 +354,40 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
describe('Edge Cases and Error Handling', () => {
test('handles localStorage access errors gracefully', async () => {
// Mock localStorage to throw an error
setupMockEnvironment(null)
const mockStorage = {
getItem: jest.fn(() => {
throw new Error('LocalStorage access denied')
}),
setItem: jest.fn(),
removeItem: jest.fn(),
clear: jest.fn(),
}
if (typeof window !== 'undefined') {
Object.defineProperty(window, 'localStorage', {
value: mockStorage,
configurable: true,
})
}
render(
<TestThemeProvider>
<PageComponent />
</TestThemeProvider>,
)
// Should fallback gracefully without crashing
await waitFor(() => {
expect(screen.getByTestId('theme-indicator')).toBeInTheDocument()
Object.defineProperty(window, 'localStorage', {
value: mockStorage,
configurable: true,
})
// Should default to light theme when localStorage fails
expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light')
try {
render(
<TestThemeProvider>
<PageComponent />
</TestThemeProvider>,
)
// Should fallback gracefully without crashing
await waitFor(() => {
expect(screen.getByTestId('theme-indicator')).toBeInTheDocument()
})
// Should default to light theme when localStorage fails
expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light')
}
finally {
Reflect.deleteProperty(window, 'localStorage')
}
})
test('handles invalid theme values in localStorage', async () => {
@@ -403,6 +439,8 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
setupMockEnvironment('dark')
expect(window.localStorage.getItem('theme')).toBe('dark')
render(
<TestThemeProvider>
<PerformanceTestComponent />

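For reference, a minimal sketch of how the rewritten setupMockEnvironment helper is meant to be driven from a test. setupMockEnvironment, TestThemeProvider and PageComponent are the helpers defined in this test file; the dark-mode text assertion is an assumption, mirrored from the 'Appearance: light' check used above.

import { render, screen, waitFor } from '@testing-library/react'

test('restores a stored dark theme (sketch)', async () => {
  // Seed real jsdom localStorage and install a dark-preferring matchMedia stub.
  setupMockEnvironment('dark', true)
  render(
    <TestThemeProvider>
      <PageComponent />
    </TestThemeProvider>,
  )
  // Assumed counterpart of the 'Appearance: light' assertion above.
  await waitFor(() => {
    expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: dark')
  })
})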
View File

@@ -17,12 +17,9 @@ import type {
import { noop } from 'lodash-es'
export type EmbeddedChatbotContextValue = {
userCanAccess?: boolean
appInfoError?: any
appInfoLoading?: boolean
appMeta?: AppMeta
appData?: AppData
appParams?: ChatConfig
appMeta: AppMeta | null
appData: AppData | null
appParams: ChatConfig | null
appChatListDataLoading?: boolean
currentConversationId: string
currentConversationItem?: ConversationItem
@@ -59,7 +56,10 @@
}
export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({
userCanAccess: false,
appData: null,
appMeta: null,
appParams: null,
appChatListDataLoading: false,
currentConversationId: '',
appPrevChatList: [],
pinnedConversationList: [],

View File

@@ -18,9 +18,6 @@ import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, getProcessedUserVariablesFromUrlParams } from '../utils'
import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
import {
fetchAppInfo,
fetchAppMeta,
fetchAppParams,
fetchChatList,
fetchConversations,
generationConversationName,
@@ -36,8 +33,7 @@ import { InputVarType } from '@/app/components/workflow/types'
import { TransferMethod } from '@/types/app'
import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils'
import { noop } from 'lodash-es'
import { useGetUserCanAccessApp } from '@/service/access-control'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useWebAppStore } from '@/context/web-app-context'
function getFormattedChatList(messages: any[]) {
const newChatList: ChatItem[] = []
@@ -67,18 +63,10 @@
export const useEmbeddedChatbot = () => {
const isInstalledApp = false
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo)
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
appId: appInfo?.app_id,
isInstalledApp,
enabled: systemFeatures.webapp_auth.enabled,
})
const appData = useMemo(() => {
return appInfo
}, [appInfo])
const appId = useMemo(() => appData?.app_id, [appData])
const appInfo = useWebAppStore(s => s.appInfo)
const appMeta = useWebAppStore(s => s.appMeta)
const appParams = useWebAppStore(s => s.appParams)
const appId = useMemo(() => appInfo?.app_id, [appInfo])
const [userId, setUserId] = useState<string>()
const [conversationId, setConversationId] = useState<string>()
@@ -145,8 +133,6 @@ export const useEmbeddedChatbot = () => {
return currentConversationId
}, [currentConversationId, newConversationId])
const { data: appParams } = useSWR(['appParams', isInstalledApp, appId], () => fetchAppParams(isInstalledApp, appId))
const { data: appMeta } = useSWR(['appMeta', isInstalledApp, appId], () => fetchAppMeta(isInstalledApp, appId))
const { data: appPinnedConversationData } = useSWR(['appConversationData', isInstalledApp, appId, true], () => fetchConversations(isInstalledApp, appId, undefined, true, 100))
const { data: appConversationData, isLoading: appConversationDataLoading, mutate: mutateAppConversationData } = useSWR(['appConversationData', isInstalledApp, appId, false], () => fetchConversations(isInstalledApp, appId, undefined, false, 100))
const { data: appChatListData, isLoading: appChatListDataLoading } = useSWR(chatShouldReloadKey ? ['appChatList', chatShouldReloadKey, isInstalledApp, appId] : null, () => fetchChatList(chatShouldReloadKey, isInstalledApp, appId))
@@ -398,16 +384,13 @@ export const useEmbeddedChatbot = () => {
}, [isInstalledApp, appId, t, notify])
return {
appInfoError,
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
isInstalledApp,
allowResetChat,
appId,
currentConversationId,
currentConversationItem,
handleConversationIdInfoChange,
appData,
appData: appInfo,
appParams: appParams || {} as ChatConfig,
appMeta,
appPinnedConversationData,

View File

@@ -101,7 +101,6 @@ const EmbeddedChatbotWrapper = () => {
const {
appData,
userCanAccess,
appParams,
appMeta,
appChatListDataLoading,
@@ -135,7 +134,6 @@
} = useEmbeddedChatbot()
return <EmbeddedChatbotContext.Provider value={{
userCanAccess,
appData,
appParams,
appMeta,

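The net effect of the two diffs above is that app info, params and meta now flow from the shared web-app store instead of per-component SWR fetches. A minimal consumer sketch, assuming only the selector shape visible in the hook above; the component and fallback text are illustrative:

import { useWebAppStore } from '@/context/web-app-context'

// Hypothetical consumer: `appInfo?.app_id` is the only field taken from the
// diff above; everything else here is illustrative.
const AppIdBadge = () => {
  const appInfo = useWebAppStore(s => s.appInfo)
  return <span>{appInfo?.app_id ?? 'loading…'}</span>
}

export default AppIdBadge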
View File

@@ -0,0 +1,152 @@
import React from 'react'
import { cleanup, fireEvent, render } from '@testing-library/react'
import InlineDeleteConfirm from './index'
// Mock react-i18next
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'common.operation.deleteConfirmTitle': 'Delete?',
'common.operation.yes': 'Yes',
'common.operation.no': 'No',
'common.operation.confirmAction': 'Please confirm your action.',
}
return translations[key] || key
},
}),
}))
afterEach(cleanup)
describe('InlineDeleteConfirm', () => {
describe('Rendering', () => {
test('should render with default text', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
)
expect(getByText('Delete?')).toBeInTheDocument()
expect(getByText('No')).toBeInTheDocument()
expect(getByText('Yes')).toBeInTheDocument()
})
test('should render with custom text', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm
title="Remove?"
confirmText="Confirm"
cancelText="Cancel"
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
expect(getByText('Remove?')).toBeInTheDocument()
expect(getByText('Cancel')).toBeInTheDocument()
expect(getByText('Confirm')).toBeInTheDocument()
})
test('should have proper ARIA attributes', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { container } = render(
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveAttribute('aria-labelledby', 'inline-delete-confirm-title')
expect(wrapper).toHaveAttribute('aria-describedby', 'inline-delete-confirm-description')
})
})
describe('Button interactions', () => {
test('should call onCancel when cancel button is clicked', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
)
fireEvent.click(getByText('No'))
expect(onCancel).toHaveBeenCalledTimes(1)
expect(onConfirm).not.toHaveBeenCalled()
})
test('should call onConfirm when confirm button is clicked', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
)
fireEvent.click(getByText('Yes'))
expect(onConfirm).toHaveBeenCalledTimes(1)
expect(onCancel).not.toHaveBeenCalled()
})
})
describe('Variant prop', () => {
test('should render with delete variant by default', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm onConfirm={onConfirm} onCancel={onCancel} />,
)
const confirmButton = getByText('Yes').closest('button')
expect(confirmButton?.className).toContain('btn-destructive')
})
test('should render without destructive class for warning variant', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm
variant="warning"
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
const confirmButton = getByText('Yes').closest('button')
expect(confirmButton?.className).not.toContain('btn-destructive')
})
test('should render without destructive class for info variant', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { getByText } = render(
<InlineDeleteConfirm
variant="info"
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
const confirmButton = getByText('Yes').closest('button')
expect(confirmButton?.className).not.toContain('btn-destructive')
})
})
describe('Custom className', () => {
test('should apply custom className to wrapper', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
const { container } = render(
<InlineDeleteConfirm
className="custom-class"
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
const wrapper = container.firstChild as HTMLElement
expect(wrapper.className).toContain('custom-class')
})
})
})

View File

@@ -0,0 +1,83 @@
'use client'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import cn from '@/utils/classnames'
export type InlineDeleteConfirmProps = {
title?: string
confirmText?: string
cancelText?: string
onConfirm: () => void
onCancel: () => void
className?: string
variant?: 'delete' | 'warning' | 'info'
}
const InlineDeleteConfirm: FC<InlineDeleteConfirmProps> = ({
title,
confirmText,
cancelText,
onConfirm,
onCancel,
className,
variant = 'delete',
}) => {
const { t } = useTranslation()
const titleText = title || t('common.operation.deleteConfirmTitle', 'Delete?')
const confirmTxt = confirmText || t('common.operation.yes', 'Yes')
const cancelTxt = cancelText || t('common.operation.no', 'No')
return (
<div
aria-labelledby="inline-delete-confirm-title"
aria-describedby="inline-delete-confirm-description"
className={cn(
'flex w-[120px] flex-col justify-center gap-1.5',
'rounded-[10px] border-[0.5px] border-components-panel-border-subtle',
'bg-components-panel-bg-blur px-2 pb-2 pt-1.5',
'backdrop-blur-[10px]',
'shadow-lg',
className,
)}
>
<div
id="inline-delete-confirm-title"
className="system-xs-semibold text-text-primary"
>
{titleText}
</div>
<div className="flex w-full items-center justify-center gap-1">
<Button
size="small"
variant="secondary"
onClick={onCancel}
aria-label={cancelTxt}
className="flex-1"
>
{cancelTxt}
</Button>
<Button
size="small"
variant="primary"
destructive={variant === 'delete'}
onClick={onConfirm}
aria-label={confirmTxt}
className="flex-1"
>
{confirmTxt}
</Button>
</div>
<span id="inline-delete-confirm-description" className="sr-only">
{t('common.operation.confirmAction', 'Please confirm your action.')}
</span>
</div>
)
}
InlineDeleteConfirm.displayName = 'InlineDeleteConfirm'
export default InlineDeleteConfirm

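A short usage sketch for InlineDeleteConfirm as defined above. The import path and the surrounding DeletableItem component are hypothetical; only the props come from the component's exported type.

import { useState } from 'react'
// Import path assumed; the diff does not show where the component file lives.
import InlineDeleteConfirm from '@/app/components/base/inline-delete-confirm'

const DeletableItem = ({ onDelete }: { onDelete: () => void }) => {
  const [confirming, setConfirming] = useState(false)
  if (!confirming)
    return <button onClick={() => setConfirming(true)}>Delete</button>
  // Swap the trigger for the inline confirm: onConfirm commits, onCancel restores.
  return (
    <InlineDeleteConfirm
      onConfirm={onDelete}
      onCancel={() => setConfirming(false)}
    />
  )
}

export default DeletableItem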
View File

@@ -7,6 +7,7 @@ import { useInvalidateStrategyProviders } from '@/service/use-strategy'
import type { Plugin, PluginDeclaration, PluginManifestInMarket } from '../../types'
import { PluginType } from '../../types'
import { useInvalidDataSourceList } from '@/service/use-pipeline'
import { useInvalidDataSourceListAuth } from '@/service/use-datasource'
const useRefreshPluginList = () => {
const invalidateInstalledPluginList = useInvalidateInstalledPluginList()
@@ -19,6 +20,8 @@ const useRefreshPluginList = () => {
const invalidateAllBuiltInTools = useInvalidateAllBuiltInTools()
const invalidateAllDataSources = useInvalidDataSourceList()
const invalidateDataSourceListAuth = useInvalidDataSourceListAuth()
const invalidateStrategyProviders = useInvalidateStrategyProviders()
return {
refreshPluginList: (manifest?: PluginManifestInMarket | Plugin | PluginDeclaration | null, refreshAllType?: boolean) => {
@@ -32,8 +35,10 @@ const useRefreshPluginList = () => {
// TODO: update suggested tools. See handleUpdatePlugins in the useMarketplacePlugins hook.
}
if ((manifest && PluginType.datasource.includes(manifest.category)) || refreshAllType)
if ((manifest && PluginType.datasource.includes(manifest.category)) || refreshAllType) {
invalidateAllDataSources()
invalidateDataSourceListAuth()
}
// model select
if ((manifest && PluginType.model.includes(manifest.category)) || refreshAllType) {

View File

@@ -1,10 +1,9 @@
import {
useCallback,
useEffect,
useMemo,
useRef,
} from 'react'
import Link from 'next/link'
import { useTranslation } from 'react-i18next'
import { RiArrowRightUpLine } from '@remixicon/react'
import { BlockEnum } from '../types'
import type {
OnSelectBlock,
@@ -14,10 +13,12 @@ import type { DataSourceDefaultValue, ToolDefaultValue } from './types'
import Tools from './tools'
import { ViewType } from './view-type-select'
import cn from '@/utils/classnames'
import type { ListRef } from '@/app/components/workflow/block-selector/market-place-plugin/list'
import { getMarketplaceUrl } from '@/utils/var'
import PluginList, { type ListRef } from '@/app/components/workflow/block-selector/market-place-plugin/list'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { DEFAULT_FILE_EXTENSIONS_IN_LOCAL_FILE_DATA_SOURCE } from './constants'
import { useMarketplacePlugins } from '../../plugins/marketplace/hooks'
import { PluginType } from '../../plugins/types'
import { useGetLanguage } from '@/context/i18n'
type AllToolsProps = {
className?: string
@@ -34,9 +35,26 @@ const DataSources = ({
onSelect,
dataSources,
}: AllToolsProps) => {
const { t } = useTranslation()
const language = useGetLanguage()
const pluginRef = useRef<ListRef>(null)
const wrapElemRef = useRef<HTMLDivElement>(null)
const isMatchingKeywords = (text: string, keywords: string) => {
return text.toLowerCase().includes(keywords.toLowerCase())
}
const filteredDatasources = useMemo(() => {
const hasFilter = searchText
if (!hasFilter)
return dataSources.filter(toolWithProvider => toolWithProvider.tools.length > 0)
return dataSources.filter((toolWithProvider) => {
return isMatchingKeywords(toolWithProvider.name, searchText) || toolWithProvider.tools.some((tool) => {
return tool.label[language].toLowerCase().includes(searchText.toLowerCase()) || tool.name.toLowerCase().includes(searchText.toLowerCase())
})
})
}, [searchText, dataSources, language])
const handleSelect = useCallback((_: any, toolDefaultValue: ToolDefaultValue) => {
let defaultValue: DataSourceDefaultValue = {
plugin_id: toolDefaultValue?.provider_id,
@@ -55,8 +73,24 @@
}
onSelect(BlockEnum.DataSource, toolDefaultValue && defaultValue)
}, [onSelect])
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const {
queryPluginsWithDebounced: fetchPlugins,
plugins: notInstalledPlugins = [],
} = useMarketplacePlugins()
useEffect(() => {
if (!enable_marketplace) return
if (searchText) {
fetchPlugins({
query: searchText,
category: PluginType.datasource,
})
}
}, [searchText, enable_marketplace])
return (
<div className={cn(className)}>
<div
@@ -66,24 +100,23 @@
>
<Tools
className={toolContentClassName}
tools={dataSources}
tools={filteredDatasources}
onSelect={handleSelect as OnSelectBlock}
viewType={ViewType.flat}
hasSearchText={!!searchText}
canNotSelectMultiple
/>
{
enable_marketplace && (
<Link
className='system-sm-medium sticky bottom-0 z-10 flex h-8 cursor-pointer items-center rounded-b-lg border-[0.5px] border-t border-components-panel-border bg-components-panel-bg-blur px-4 py-1 text-text-accent-light-mode-only shadow-lg'
href={getMarketplaceUrl('')}
target='_blank'
>
<span>{t('plugin.findMoreInMarketplace')}</span>
<RiArrowRightUpLine className='ml-0.5 h-3 w-3' />
</Link>
)
}
{/* Plugins from marketplace */}
{enable_marketplace && (
<PluginList
ref={pluginRef}
wrapElemRef={wrapElemRef}
list={notInstalledPlugins}
tags={[]}
searchText={searchText}
toolContentClassName={toolContentClassName}
/>
)}
</div>
</div>
)

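The search behaviour added above boils down to a case-insensitive match over provider names, tool names and the localized tool label. A standalone sketch of the same logic, with the provider type reduced to the fields the filter actually reads:

type SketchTool = { name: string; label: Record<string, string> }
type SketchProvider = { name: string; tools: SketchTool[] }

const matches = (text: string, keywords: string) =>
  text.toLowerCase().includes(keywords.toLowerCase())

// Mirrors filteredDatasources: without search text, keep providers that have
// tools; otherwise match on provider name, localized tool label, or tool name.
const filterDataSources = (providers: SketchProvider[], searchText: string, language: string) => {
  if (!searchText)
    return providers.filter(p => p.tools.length > 0)
  return providers.filter(p =>
    matches(p.name, searchText)
    || p.tools.some(t => matches(t.label[language] ?? '', searchText) || matches(t.name, searchText)))
}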
View File

@@ -0,0 +1,94 @@
import type { ComponentType } from 'react'
import { BlockEnum } from '../types'
import StartNode from './start/node'
import StartPanel from './start/panel'
import EndNode from './end/node'
import EndPanel from './end/panel'
import AnswerNode from './answer/node'
import AnswerPanel from './answer/panel'
import LLMNode from './llm/node'
import LLMPanel from './llm/panel'
import KnowledgeRetrievalNode from './knowledge-retrieval/node'
import KnowledgeRetrievalPanel from './knowledge-retrieval/panel'
import QuestionClassifierNode from './question-classifier/node'
import QuestionClassifierPanel from './question-classifier/panel'
import IfElseNode from './if-else/node'
import IfElsePanel from './if-else/panel'
import CodeNode from './code/node'
import CodePanel from './code/panel'
import TemplateTransformNode from './template-transform/node'
import TemplateTransformPanel from './template-transform/panel'
import HttpNode from './http/node'
import HttpPanel from './http/panel'
import ToolNode from './tool/node'
import ToolPanel from './tool/panel'
import VariableAssignerNode from './variable-assigner/node'
import VariableAssignerPanel from './variable-assigner/panel'
import AssignerNode from './assigner/node'
import AssignerPanel from './assigner/panel'
import ParameterExtractorNode from './parameter-extractor/node'
import ParameterExtractorPanel from './parameter-extractor/panel'
import IterationNode from './iteration/node'
import IterationPanel from './iteration/panel'
import LoopNode from './loop/node'
import LoopPanel from './loop/panel'
import DocExtractorNode from './document-extractor/node'
import DocExtractorPanel from './document-extractor/panel'
import ListFilterNode from './list-operator/node'
import ListFilterPanel from './list-operator/panel'
import AgentNode from './agent/node'
import AgentPanel from './agent/panel'
import DataSourceNode from './data-source/node'
import DataSourcePanel from './data-source/panel'
import KnowledgeBaseNode from './knowledge-base/node'
import KnowledgeBasePanel from './knowledge-base/panel'
export const NodeComponentMap: Record<string, ComponentType<any>> = {
[BlockEnum.Start]: StartNode,
[BlockEnum.End]: EndNode,
[BlockEnum.Answer]: AnswerNode,
[BlockEnum.LLM]: LLMNode,
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalNode,
[BlockEnum.QuestionClassifier]: QuestionClassifierNode,
[BlockEnum.IfElse]: IfElseNode,
[BlockEnum.Code]: CodeNode,
[BlockEnum.TemplateTransform]: TemplateTransformNode,
[BlockEnum.HttpRequest]: HttpNode,
[BlockEnum.Tool]: ToolNode,
[BlockEnum.VariableAssigner]: VariableAssignerNode,
[BlockEnum.Assigner]: AssignerNode,
[BlockEnum.VariableAggregator]: VariableAssignerNode,
[BlockEnum.ParameterExtractor]: ParameterExtractorNode,
[BlockEnum.Iteration]: IterationNode,
[BlockEnum.Loop]: LoopNode,
[BlockEnum.DocExtractor]: DocExtractorNode,
[BlockEnum.ListFilter]: ListFilterNode,
[BlockEnum.Agent]: AgentNode,
[BlockEnum.DataSource]: DataSourceNode,
[BlockEnum.KnowledgeBase]: KnowledgeBaseNode,
}
export const PanelComponentMap: Record<string, ComponentType<any>> = {
[BlockEnum.Start]: StartPanel,
[BlockEnum.End]: EndPanel,
[BlockEnum.Answer]: AnswerPanel,
[BlockEnum.LLM]: LLMPanel,
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalPanel,
[BlockEnum.QuestionClassifier]: QuestionClassifierPanel,
[BlockEnum.IfElse]: IfElsePanel,
[BlockEnum.Code]: CodePanel,
[BlockEnum.TemplateTransform]: TemplateTransformPanel,
[BlockEnum.HttpRequest]: HttpPanel,
[BlockEnum.Tool]: ToolPanel,
[BlockEnum.VariableAssigner]: VariableAssignerPanel,
[BlockEnum.VariableAggregator]: VariableAssignerPanel,
[BlockEnum.Assigner]: AssignerPanel,
[BlockEnum.ParameterExtractor]: ParameterExtractorPanel,
[BlockEnum.Iteration]: IterationPanel,
[BlockEnum.Loop]: LoopPanel,
[BlockEnum.DocExtractor]: DocExtractorPanel,
[BlockEnum.ListFilter]: ListFilterPanel,
[BlockEnum.Agent]: AgentPanel,
[BlockEnum.DataSource]: DataSourcePanel,
[BlockEnum.KnowledgeBase]: KnowledgeBasePanel,
}

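These two maps are keyed by BlockEnum and, as the index diff below shows, are now imported from './components'. A hypothetical lookup helper to illustrate how a renderer pair is resolved:

import type { ComponentType } from 'react'
import { BlockEnum } from '../types'
import { NodeComponentMap, PanelComponentMap } from './components'

// Resolve the node and panel renderers for a block type; entries missing from
// the maps simply come back undefined.
const resolveRenderers = (type: BlockEnum) => ({
  Node: NodeComponentMap[type] as ComponentType<any> | undefined,
  Panel: PanelComponentMap[type] as ComponentType<any> | undefined,
})

// e.g. the LLM node's renderer pair:
const { Node: LlmNode, Panel: LlmPanel } = resolveRenderers(BlockEnum.LLM)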
View File

@@ -1,101 +1,5 @@
import type { ComponentType } from 'react'
import { BlockEnum } from '../types'
import StartNode from './start/node'
import StartPanel from './start/panel'
import EndNode from './end/node'
import EndPanel from './end/panel'
import AnswerNode from './answer/node'
import AnswerPanel from './answer/panel'
import LLMNode from './llm/node'
import LLMPanel from './llm/panel'
import KnowledgeRetrievalNode from './knowledge-retrieval/node'
import KnowledgeRetrievalPanel from './knowledge-retrieval/panel'
import QuestionClassifierNode from './question-classifier/node'
import QuestionClassifierPanel from './question-classifier/panel'
import IfElseNode from './if-else/node'
import IfElsePanel from './if-else/panel'
import CodeNode from './code/node'
import CodePanel from './code/panel'
import TemplateTransformNode from './template-transform/node'
import TemplateTransformPanel from './template-transform/panel'
import HttpNode from './http/node'
import HttpPanel from './http/panel'
import ToolNode from './tool/node'
import ToolPanel from './tool/panel'
import VariableAssignerNode from './variable-assigner/node'
import VariableAssignerPanel from './variable-assigner/panel'
import AssignerNode from './assigner/node'
import AssignerPanel from './assigner/panel'
import ParameterExtractorNode from './parameter-extractor/node'
import ParameterExtractorPanel from './parameter-extractor/panel'
import IterationNode from './iteration/node'
import IterationPanel from './iteration/panel'
import LoopNode from './loop/node'
import LoopPanel from './loop/panel'
import DocExtractorNode from './document-extractor/node'
import DocExtractorPanel from './document-extractor/panel'
import ListFilterNode from './list-operator/node'
import ListFilterPanel from './list-operator/panel'
import AgentNode from './agent/node'
import AgentPanel from './agent/panel'
import DataSourceNode from './data-source/node'
import DataSourcePanel from './data-source/panel'
import KnowledgeBaseNode from './knowledge-base/node'
import KnowledgeBasePanel from './knowledge-base/panel'
import { TransferMethod } from '@/types/app'
export const NodeComponentMap: Record<string, ComponentType<any>> = {
[BlockEnum.Start]: StartNode,
[BlockEnum.End]: EndNode,
[BlockEnum.Answer]: AnswerNode,
[BlockEnum.LLM]: LLMNode,
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalNode,
[BlockEnum.QuestionClassifier]: QuestionClassifierNode,
[BlockEnum.IfElse]: IfElseNode,
[BlockEnum.Code]: CodeNode,
[BlockEnum.TemplateTransform]: TemplateTransformNode,
[BlockEnum.HttpRequest]: HttpNode,
[BlockEnum.Tool]: ToolNode,
[BlockEnum.VariableAssigner]: VariableAssignerNode,
[BlockEnum.Assigner]: AssignerNode,
[BlockEnum.VariableAggregator]: VariableAssignerNode,
[BlockEnum.ParameterExtractor]: ParameterExtractorNode,
[BlockEnum.Iteration]: IterationNode,
[BlockEnum.Loop]: LoopNode,
[BlockEnum.DocExtractor]: DocExtractorNode,
[BlockEnum.ListFilter]: ListFilterNode,
[BlockEnum.Agent]: AgentNode,
[BlockEnum.DataSource]: DataSourceNode,
[BlockEnum.KnowledgeBase]: KnowledgeBaseNode,
}
export const PanelComponentMap: Record<string, ComponentType<any>> = {
[BlockEnum.Start]: StartPanel,
[BlockEnum.End]: EndPanel,
[BlockEnum.Answer]: AnswerPanel,
[BlockEnum.LLM]: LLMPanel,
[BlockEnum.KnowledgeRetrieval]: KnowledgeRetrievalPanel,
[BlockEnum.QuestionClassifier]: QuestionClassifierPanel,
[BlockEnum.IfElse]: IfElsePanel,
[BlockEnum.Code]: CodePanel,
[BlockEnum.TemplateTransform]: TemplateTransformPanel,
[BlockEnum.HttpRequest]: HttpPanel,
[BlockEnum.Tool]: ToolPanel,
[BlockEnum.VariableAssigner]: VariableAssignerPanel,
[BlockEnum.VariableAggregator]: VariableAssignerPanel,
[BlockEnum.Assigner]: AssignerPanel,
[BlockEnum.ParameterExtractor]: ParameterExtractorPanel,
[BlockEnum.Iteration]: IterationPanel,
[BlockEnum.Loop]: LoopPanel,
[BlockEnum.DocExtractor]: DocExtractorPanel,
[BlockEnum.ListFilter]: ListFilterPanel,
[BlockEnum.Agent]: AgentPanel,
[BlockEnum.DataSource]: DataSourcePanel,
[BlockEnum.KnowledgeBase]: KnowledgeBasePanel,
}
export const CUSTOM_NODE_TYPE = 'custom'
export const FILE_TYPE_OPTIONS = [
{ value: 'image', i18nKey: 'image' },
{ value: 'document', i18nKey: 'doc' },
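The hunk above removes the import block and both component maps from constants.ts; they now live in the new components.ts module shown earlier, and constants.ts keeps only plain constants such as CUSTOM_NODE_TYPE and FILE_TYPE_OPTIONS. A minimal sketch of how a renderer could resolve components from the relocated maps — the helper names and the null fallback here are assumptions for illustration, not part of this commit:

// Hypothetical consumer of the relocated maps (names assumed).
import type { ComponentType } from 'react'
import { BlockEnum } from '../types'
import { NodeComponentMap, PanelComponentMap } from './components'

// Fall back to a component that renders nothing when a block type
// has no registered renderer in the map.
const Empty: ComponentType<any> = () => null

export const resolveNodeComponent = (type: BlockEnum): ComponentType<any> =>
  NodeComponentMap[type] ?? Empty

export const resolvePanelComponent = (type: BlockEnum): ComponentType<any> =>
  PanelComponentMap[type] ?? Empty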

View File

@@ -8,7 +8,7 @@ import { CUSTOM_NODE } from '../constants'
import {
NodeComponentMap,
PanelComponentMap,
} from './constants'
} from './components'
import BaseNode from './_base/node'
import BasePanel from './_base/components/workflow-panel'

View File

@@ -29,7 +29,7 @@ const ChunkStructure = ({
<Field
fieldTitleProps={{
title: t('workflow.nodes.knowledgeBase.chunkStructure'),
tooltip: t('workflow.nodes.knowledgeBase.chunkStructure'),
tooltip: t('workflow.nodes.knowledgeBase.chunkStructureTip.message'),
operation: chunkStructure && (
<Selector
options={options}
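The one-line fix above stops the tooltip from reusing the field title's translation key and points it at a dedicated chunkStructureTip.message entry instead. A sketch of the corresponding resource shape — the English copy below is placeholder text, not the actual string shipped in this commit:

// Illustration only: label and tooltip copy kept under separate keys.
const knowledgeBase = {
  chunkStructure: 'Chunk Structure',
  chunkStructureTip: {
    // Placeholder copy; the real string lives in the i18n resources.
    message: 'How documents are split into chunks before indexing.',
  },
}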

View File

@@ -61,6 +61,10 @@ const translation = {
selectAll: 'Alles auswählen',
deSelectAll: 'Alle abwählen',
config: 'Konfiguration',
yes: 'Ja',
deleteConfirmTitle: 'Löschen?',
no: 'Nein',
confirmAction: 'Bitte bestätigen Sie Ihre Aktion.',
},
placeholder: {
input: 'Bitte eingeben',

View File

@@ -18,6 +18,10 @@ const translation = {
cancel: 'Cancel',
clear: 'Clear',
save: 'Save',
yes: 'Yes',
no: 'No',
deleteConfirmTitle: 'Delete?',
confirmAction: 'Please confirm your action.',
saveAndEnable: 'Save & Enable',
edit: 'Edit',
add: 'Add',
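The same four keys are added for every locale in this commit; the English hunk above carries the source strings. A minimal sketch of a delete confirmation wired to them with react-i18next — the common.operation namespace, the delete key, and the markup are assumptions for illustration:

// Hypothetical confirm dialog using the new translation keys.
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

const DeleteEntry = ({ onConfirm }: { onConfirm: () => void }) => {
  const { t } = useTranslation()
  const [open, setOpen] = useState(false)

  return (
    <>
      {/* 'common.operation.delete' is assumed; only yes/no/deleteConfirmTitle/confirmAction are added here. */}
      <button onClick={() => setOpen(true)}>{t('common.operation.delete')}</button>
      {open && (
        <div role="dialog">
          <h3>{t('common.operation.deleteConfirmTitle')}</h3>
          <p>{t('common.operation.confirmAction')}</p>
          <button onClick={() => { onConfirm(); setOpen(false) }}>{t('common.operation.yes')}</button>
          <button onClick={() => setOpen(false)}>{t('common.operation.no')}</button>
        </div>
      )}
    </>
  )
}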

View File

@@ -61,6 +61,10 @@ const translation = {
deSelectAll: 'Deseleccionar todo',
selectAll: 'Seleccionar todo',
config: 'Config',
confirmAction: 'Por favor, confirme su acción.',
deleteConfirmTitle: '¿Eliminar?',
yes: 'Sí',
no: 'No',
},
errorMsg: {
fieldRequired: '{{field}} es requerido',

View File

@@ -61,6 +61,10 @@ const translation = {
selectAll: 'انتخاب همه',
deSelectAll: 'همه را انتخاب نکنید',
config: 'تنظیمات',
no: 'نه',
deleteConfirmTitle: 'حذف شود؟',
yes: 'بله',
confirmAction: 'لطفاً اقدام خود را تأیید کنید.',
},
errorMsg: {
fieldRequired: '{{field}} الزامی است',

View File

@@ -61,6 +61,10 @@ const translation = {
deSelectAll: 'Désélectionner tout',
selectAll: 'Sélectionner tout',
config: 'Config',
no: 'Non',
confirmAction: 'Veuillez confirmer votre action.',
deleteConfirmTitle: 'Supprimer ?',
yes: 'Oui',
},
placeholder: {
input: 'Veuillez entrer',

View File

@@ -61,6 +61,10 @@ const translation = {
selectAll: 'सभी चुनें',
deSelectAll: 'सभी चयन हटाएँ',
config: 'कॉन्फ़िगरेशन',
no: 'नहीं',
yes: 'हाँ',
deleteConfirmTitle: 'हटाएं?',
confirmAction: 'कृपया अपनी क्रिया की पुष्टि करें।',
},
errorMsg: {
fieldRequired: '{{field}} आवश्यक है',

View File

@@ -67,6 +67,10 @@ const translation = {
sure: 'Saya yakin',
imageCopied: 'Gambar yang disalin',
config: 'Konfigurasi',
deleteConfirmTitle: 'Hapus?',
confirmAction: 'Silakan konfirmasi tindakan Anda.',
yes: 'Ya',
no: 'Tidak',
},
errorMsg: {
urlError: 'URL harus dimulai dengan http:// atau https://',

View File

@@ -61,6 +61,10 @@ const translation = {
selectAll: 'Seleziona tutto',
deSelectAll: 'Deseleziona tutto',
config: 'Config',
no: 'No',
yes: 'Sì',
confirmAction: 'Per favore conferma la tua azione.',
deleteConfirmTitle: 'Eliminare?',
},
errorMsg: {
fieldRequired: '{{field}} è obbligatorio',

View File

@@ -67,6 +67,10 @@ const translation = {
selectAll: 'すべて選択',
deSelectAll: 'すべて選択解除',
config: 'コンフィグ',
yes: 'はい',
no: 'いいえ',
deleteConfirmTitle: '削除しますか?',
confirmAction: '操作を確認してください。',
},
errorMsg: {
fieldRequired: '{{field}}は必要です',

View File

@@ -61,6 +61,10 @@ const translation = {
selectAll: '모두 선택',
deSelectAll: '모두 선택 해제',
config: '구성',
no: '아니요',
yes: '네',
deleteConfirmTitle: '삭제하시겠습니까?',
confirmAction: '귀하의 행동을 확인해 주세요.',
},
placeholder: {
input: '입력해주세요',

Some files were not shown because too many files have changed in this diff.