Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN- 2025-09-08 13:56:45 +08:00
commit 299141ae01
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
89 changed files with 2095 additions and 759 deletions

8
.gitignore vendored
View File

@ -198,6 +198,7 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
api/.vscode api/.vscode
web/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history
@ -215,6 +216,13 @@ mise.toml
# Next.js build output # Next.js build output
.next/ .next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant # AI Assistant
.roo/ .roo/
api/.env.backup api/.env.backup

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
def admin_required(view): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY: if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.") raise Unauthorized("API key is invalid.")

View File

@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource):
custom="max_keys_exceeded", custom="max_keys_exceeded",
) )
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_user.current_tenant_id

View File

@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import cast from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login import flask_login
from flask import jsonify, request from flask import jsonify, request
@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json") parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args() parsed_args = parser.parse_args()
@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app: if not oauth_provider_app:
raise NotFound("client_id is invalid") raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app return view(self, oauth_provider_app, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated
def oauth_server_access_token_required(view): def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app") if not isinstance(oauth_provider_app, OAuthProviderApp):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app") raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization") authorization_header = request.headers.get("Authorization")
@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer" response.headers["WWW-Authenticate"] = "Bearer"
return response return response
kwargs["account"] = account return view(self, oauth_provider_app, account, *args, **kwargs)
return view(*args, **kwargs)
return decorated return decorated

View File

@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import login_required from libs.login import current_user, login_required
from models.model import Account
from services.billing_service import BillingService from services.billing_service import BillingService
@ -17,9 +17,10 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args() args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription( return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
) )
@ -31,7 +32,9 @@ class Invoices(Resource):
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -475,6 +475,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if document.data_source_type == "upload_file": if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
@ -491,6 +493,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
@ -503,6 +507,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={

View File

@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)

View File

@ -36,6 +36,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
Run workflow Run workflow
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
@ -74,6 +76,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
Stop workflow task Stop workflow task
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource from flask_restx import Resource
@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def decorator(view): def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .where(
@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator return decorator
def user_allowed_to_access_app(view=None): def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view): def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view) @wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs): def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features() feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled: if feature.webapp_auth.enabled:
app_id = installed_app.app_id app_id = installed_app.app_id

View File

@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantPluginPermission from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required( def plugin_permission_required(
install_required: bool = False, install_required: bool = False,
debug_required: bool = False, debug_required: bool = False,
): ):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id

View File

@ -2,7 +2,9 @@ import contextlib
import json import json
import os import os
import time import time
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request from flask import abort, request
from flask_login import current_user from flask_login import current_user
@ -19,10 +21,13 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization # check account initialization
account = current_user account = current_user
@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated return decorated
def only_edition_cloud(view): def only_edition_cloud(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
abort(404) abort(404)
@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated return decorated
def only_edition_enterprise(view): def only_edition_enterprise(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED: if not dify_config.ENTERPRISE_ENABLED:
abort(404) abort(404)
@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated return decorated
def only_edition_self_hosted(view): def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
abort(404) abort(404)
@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated return decorated
def cloud_edition_billing_enabled(view): def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled: if not features.billing.enabled:
abort(403, "Billing feature is not enabled.") abort(403, "Billing feature is not enabled.")
@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str): def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
if resource == "add_segment": if resource == "add_segment":
@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str): def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge": if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled: if knowledge_rate_limit.enabled:
@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor return interceptor
def cloud_utm_record(view): def cloud_utm_record(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated return decorated
def setup_required(view): def setup_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup # check setup
if ( if (
dify_config.EDITION == "SELF_HOSTED" dify_config.EDITION == "SELF_HOSTED"
@ -212,9 +217,9 @@ def setup_required(view):
return decorated return decorated
def enterprise_license_required(view): def enterprise_license_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features() settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated return decorated
def email_password_login_enabled(view): def email_password_login_enabled(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_email_password_login: if features.enable_email_password_login:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated return decorated
def enable_change_email(view): def enable_change_email(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features() features = FeatureService.get_system_features()
if features.enable_change_email: if features.enable_change_email:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated return decorated
def is_allow_transfer_owner(view): def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace: if features.is_allow_transfer_workspace:
return view(*args, **kwargs) return view(*args, **kwargs)

View File

@ -3,7 +3,7 @@ from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from enum import StrEnum, auto from enum import StrEnum, auto
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional, ParamSpec, TypeVar
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
@ -22,6 +22,9 @@ from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
class WhereisUserArg(StrEnum): class WhereisUserArg(StrEnum):
""" """
@ -60,27 +63,6 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model kwargs["app_model"] = app_model
if fetch_user_arg: if fetch_user_arg:
@ -118,8 +100,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
@ -148,9 +130,9 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id) features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled: if features.billing.enabled:
@ -170,9 +152,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view): def interceptor(view: Callable[P, R]):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type) api_token = validate_and_get_api_token(api_token_type)
if resource == "knowledge": if resource == "knowledge":

View File

@ -1,5 +1,6 @@
from datetime import UTC, datetime from datetime import UTC, datetime
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource
@ -15,6 +16,9 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token(view=None): def validate_jwt_token(view=None):
def decorator(view): def decorator(view):

View File

@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
raise MessageNotExistsError() raise MessageNotExistsError()
current_app_model_config = app_model.app_model_config current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()
more_like_this = current_app_model_config.more_like_this_dict more_like_this = current_app_model_config.more_like_this_dict
if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:

View File

@ -124,6 +124,7 @@ class TokenBufferMemory:
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))
curr_message_tokens = 0
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
for message in messages: for message in messages:
# Process user message with files # Process user message with files

View File

@ -17,6 +17,10 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset from models.dataset import Dataset
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class MatrixoneConfig(BaseModel): class MatrixoneConfig(BaseModel):

View File

@ -334,7 +334,8 @@ class NotionExtractor(BaseExtractor):
last_edited_time = self.get_notion_last_edited_time() last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict data_source_info = document_model.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time if data_source_info:
data_source_info["last_edited_time"] = last_edited_time
db.session.query(DocumentModel).filter_by(id=document_model.id).update( db.session.query(DocumentModel).filter_by(id=document_model.id).update(
{DocumentModel.data_source_info: json.dumps(data_source_info)} {DocumentModel.data_source_info: json.dumps(data_source_info)}

View File

@ -1,5 +1,5 @@
import json import json
from typing import Any, Optional from typing import Any, Optional, Self
from core.mcp.types import Tool as RemoteMCPTool from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController):
return ToolProviderType.MCP return ToolProviderType.MCP
@classmethod @classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": def from_db(cls, db_provider: MCPToolProvider) -> Self:
""" """
from db provider from db provider
""" """

View File

@ -777,7 +777,7 @@ class ToolManager:
if provider is None: if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController._from_db(provider) controller = MCPToolProviderController.from_db(provider)
return controller return controller
@ -932,7 +932,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
provider_type: ToolProviderType, provider_type: ToolProviderType,
provider_id: str, provider_id: str,
) -> Union[str, dict]: ) -> Union[str, dict[str, Any]]:
""" """
get the tool icon get the tool icon

View File

@ -3,7 +3,7 @@ from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities import VariablePool from core.workflow.entities import VariablePool
from core.workflow.enums import ( from core.workflow.enums import (
@ -97,10 +97,10 @@ class IterationNode(Node):
if not variable: if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0: if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type informat. # Try our best to preserve the type informat.
if isinstance(variable, ArraySegment): if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []}) output = variable.model_copy(update={"value": []})

View File

@ -50,6 +50,7 @@ from .exc import (
) )
from .prompts import ( from .prompts import (
CHAT_EXAMPLE, CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT, COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE, FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
@ -746,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage( system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM, role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
) )
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message] return [system_prompt_messages, user_prompt_message]

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Union, cast from typing import Union, cast
@ -12,9 +13,13 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an #: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user #: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func): def login_required(func: Callable[P, R]):
""" """
If you decorate a view with this, it will ensure that the current user is If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are logged in and authenticated before calling the actual view. (If they are
@ -49,17 +54,12 @@ def login_required(func):
""" """
@wraps(func) @wraps(func)
def decorated_view(*args, **kwargs): def decorated_view(*args: P.args, **kwargs: P.kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass pass
elif current_user is not None and not current_user.is_authenticated: elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore return current_app.login_manager.unauthorized() # type: ignore
return current_app.ensure_sync(func)(*args, **kwargs)
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return decorated_view return decorated_view

View File

@ -1,10 +1,10 @@
import enum import enum
import json import json
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Any, Optional
import sqlalchemy as sa import sqlalchemy as sa
from flask_login import UserMixin from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
@ -225,11 +225,11 @@ class Tenant(Base):
) )
@property @property
def custom_config_dict(self): def custom_config_dict(self) -> dict[str, Any]:
return json.loads(self.custom_config) if self.custom_config else {} return json.loads(self.custom_config) if self.custom_config else {}
@custom_config_dict.setter @custom_config_dict.setter
def custom_config_dict(self, value: dict): def custom_config_dict(self, value: dict[str, Any]) -> None:
self.custom_config = json.dumps(value) self.custom_config = json.dumps(value)

View File

@ -286,7 +286,7 @@ class DatasetProcessRule(Base):
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
} }
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"dataset_id": self.dataset_id, "dataset_id": self.dataset_id,
@ -295,7 +295,7 @@ class DatasetProcessRule(Base):
} }
@property @property
def rules_dict(self): def rules_dict(self) -> dict[str, Any] | None:
try: try:
return json.loads(self.rules) if self.rules else None return json.loads(self.rules) if self.rules else None
except JSONDecodeError: except JSONDecodeError:
@ -392,10 +392,10 @@ class Document(Base):
return status return status
@property @property
def data_source_info_dict(self): def data_source_info_dict(self) -> dict[str, Any] | None:
if self.data_source_info: if self.data_source_info:
try: try:
data_source_info_dict = json.loads(self.data_source_info) data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
except JSONDecodeError: except JSONDecodeError:
data_source_info_dict = {} data_source_info_dict = {}
@ -403,10 +403,10 @@ class Document(Base):
return None return None
@property @property
def data_source_detail_dict(self): def data_source_detail_dict(self) -> dict[str, Any]:
if self.data_source_info: if self.data_source_info:
if self.data_source_type == "upload_file": if self.data_source_type == "upload_file":
data_source_info_dict = json.loads(self.data_source_info) data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"]) .where(UploadFile.id == data_source_info_dict["upload_file_id"])
@ -425,7 +425,8 @@ class Document(Base):
} }
} }
elif self.data_source_type in {"notion_import", "website_crawl"}: elif self.data_source_type in {"notion_import", "website_crawl"}:
return json.loads(self.data_source_info) result: dict[str, Any] = json.loads(self.data_source_info)
return result
return {} return {}
@property @property
@ -471,7 +472,7 @@ class Document(Base):
return self.updated_at return self.updated_at
@property @property
def doc_metadata_details(self): def doc_metadata_details(self) -> list[dict[str, Any]] | None:
if self.doc_metadata: if self.doc_metadata:
document_metadatas = ( document_metadatas = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
@ -481,9 +482,9 @@ class Document(Base):
) )
.all() .all()
) )
metadata_list = [] metadata_list: list[dict[str, Any]] = []
for metadata in document_metadatas: for metadata in document_metadatas:
metadata_dict = { metadata_dict: dict[str, Any] = {
"id": metadata.id, "id": metadata.id,
"name": metadata.name, "name": metadata.name,
"type": metadata.type, "type": metadata.type,
@ -497,13 +498,13 @@ class Document(Base):
return None return None
@property @property
def process_rule_dict(self): def process_rule_dict(self) -> dict[str, Any] | None:
if self.dataset_process_rule_id: if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict() return self.dataset_process_rule.to_dict()
return None return None
def get_built_in_fields(self): def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields = [] built_in_fields: list[dict[str, Any]] = []
built_in_fields.append( built_in_fields.append(
{ {
"id": "built-in", "id": "built-in",
@ -546,7 +547,7 @@ class Document(Base):
) )
return built_in_fields return built_in_fields
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
@ -592,13 +593,13 @@ class Document(Base):
"data_source_info_dict": self.data_source_info_dict, "data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length, "average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": self.dataset.to_dict() if self.dataset else None, "dataset": None, # Dataset class doesn't have a to_dict method
"segment_count": self.segment_count, "segment_count": self.segment_count,
"hit_count": self.hit_count, "hit_count": self.hit_count,
} }
@classmethod @classmethod
def from_dict(cls, data: dict): def from_dict(cls, data: dict[str, Any]):
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
@ -711,46 +712,48 @@ class DocumentSegment(Base):
) )
@property @property
def child_chunks(self): def child_chunks(self) -> list[Any]:
process_rule = self.document.dataset_process_rule if not self.document:
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return [] return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
def get_child_chunks(self): def get_child_chunks(self) -> list[Any]:
process_rule = self.document.dataset_process_rule if not self.document:
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return [] return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
@property @property
def sign_content(self): def sign_content(self) -> str:
return self.get_sign_content() return self.get_sign_content()
def get_sign_content(self): def get_sign_content(self) -> str:
signed_urls = [] signed_urls: list[tuple[int, int, str]] = []
text = self.content text = self.content
# For data before v0.10.0 # For data before v0.10.0
@ -890,17 +893,22 @@ class DatasetKeywordTable(Base):
) )
@property @property
def keyword_table_dict(self): def keyword_table_dict(self) -> dict[str, set[Any]] | None:
class SetDecoder(json.JSONDecoder): class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(object_hook=self.object_hook, *args, **kwargs) def object_hook(dct: Any) -> Any:
if isinstance(dct, dict):
result: dict[str, Any] = {}
items = cast(dict[str, Any], dct).items()
for keyword, node_idxs in items:
if isinstance(node_idxs, list):
result[keyword] = set(cast(list[Any], node_idxs))
else:
result[keyword] = node_idxs
return result
return dct
def object_hook(self, dct): super().__init__(object_hook=object_hook, *args, **kwargs)
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
# get dataset # get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base):
updated_by = mapped_column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base):
} }
@property @property
def settings_dict(self): def settings_dict(self) -> dict[str, Any] | None:
try: try:
return json.loads(self.settings) if self.settings else None return json.loads(self.settings) if self.settings else None
except JSONDecodeError: except JSONDecodeError:
return None return None
@property @property
def dataset_bindings(self): def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = ( external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings) db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base):
) )
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = [] dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets: for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name}) dataset_bindings.append({"id": dataset.id, "name": dataset.name})

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
@ -18,7 +18,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.tools.signature import sign_tool_file from core.tools.signature import sign_tool_file
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from libs.helper import generate_string from libs.helper import generate_string # type: ignore[import-not-found]
from .account import Account, Tenant from .account import Account, Tenant
from .base import Base from .base import Base
@ -96,7 +96,7 @@ class App(Base):
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property @property
def desc_or_prompt(self): def desc_or_prompt(self) -> str:
if self.description: if self.description:
return self.description return self.description
else: else:
@ -107,12 +107,12 @@ class App(Base):
return "" return ""
@property @property
def site(self): def site(self) -> Optional["Site"]:
site = db.session.query(Site).where(Site.app_id == self.id).first() site = db.session.query(Site).where(Site.app_id == self.id).first()
return site return site
@property @property
def app_model_config(self): def app_model_config(self) -> Optional["AppModelConfig"]:
if self.app_model_config_id: if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
@ -128,11 +128,11 @@ class App(Base):
return None return None
@property @property
def api_base_url(self): def api_base_url(self) -> str:
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
@property @property
def tenant(self): def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant
@ -160,9 +160,8 @@ class App(Base):
return str(self.mode) return str(self.mode)
@property @property
def deleted_tools(self) -> list: def deleted_tools(self) -> list[dict[str, str]]:
from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager, ToolProviderType
from core.tools.tool_manager import ToolManager
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
# get agent mode tools # get agent mode tools
@ -242,7 +241,7 @@ class App(Base):
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
} }
deleted_tools = [] deleted_tools: list[dict[str, str]] = []
for tool in tools: for tool in tools:
keys = list(tool.keys()) keys = list(tool.keys())
@ -275,7 +274,7 @@ class App(Base):
return deleted_tools return deleted_tools
@property @property
def tags(self): def tags(self) -> list["Tag"]:
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
@ -291,7 +290,7 @@ class App(Base):
return tags or [] return tags or []
@property @property
def author_name(self): def author_name(self) -> Optional[str]:
if self.created_by: if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first() account = db.session.query(Account).where(Account.id == self.created_by).first()
if account: if account:
@ -334,20 +333,20 @@ class AppModelConfig(Base):
file_upload = mapped_column(sa.Text) file_upload = mapped_column(sa.Text)
@property @property
def app(self): def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@property @property
def model_dict(self): def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {} return json.loads(self.model) if self.model else {}
@property @property
def suggested_questions_list(self): def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else [] return json.loads(self.suggested_questions) if self.suggested_questions else []
@property @property
def suggested_questions_after_answer_dict(self): def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.suggested_questions_after_answer) json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer if self.suggested_questions_after_answer
@ -355,19 +354,19 @@ class AppModelConfig(Base):
) )
@property @property
def speech_to_text_dict(self): def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
@property @property
def text_to_speech_dict(self): def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
@property @property
def retriever_resource_dict(self): def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
@property @property
def annotation_reply_dict(self): def annotation_reply_dict(self) -> dict[str, Any]:
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
) )
@ -390,11 +389,11 @@ class AppModelConfig(Base):
return {"enabled": False} return {"enabled": False}
@property @property
def more_like_this_dict(self): def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@property @property
def sensitive_word_avoidance_dict(self): def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.sensitive_word_avoidance) json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance if self.sensitive_word_avoidance
@ -402,15 +401,15 @@ class AppModelConfig(Base):
) )
@property @property
def external_data_tools_list(self) -> list[dict]: def external_data_tools_list(self) -> list[dict[str, Any]]:
return json.loads(self.external_data_tools) if self.external_data_tools else [] return json.loads(self.external_data_tools) if self.external_data_tools else []
@property @property
def user_input_form_list(self): def user_input_form_list(self) -> list[dict[str, Any]]:
return json.loads(self.user_input_form) if self.user_input_form else [] return json.loads(self.user_input_form) if self.user_input_form else []
@property @property
def agent_mode_dict(self): def agent_mode_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.agent_mode) json.loads(self.agent_mode)
if self.agent_mode if self.agent_mode
@ -418,17 +417,17 @@ class AppModelConfig(Base):
) )
@property @property
def chat_prompt_config_dict(self): def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
@property @property
def completion_prompt_config_dict(self): def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
@property @property
def dataset_configs_dict(self): def dataset_configs_dict(self) -> dict[str, Any]:
if self.dataset_configs: if self.dataset_configs:
dataset_configs: dict = json.loads(self.dataset_configs) dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs: if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"} return {"retrieval_model": "single"}
else: else:
@ -438,7 +437,7 @@ class AppModelConfig(Base):
} }
@property @property
def file_upload_dict(self): def file_upload_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.file_upload) json.loads(self.file_upload)
if self.file_upload if self.file_upload
@ -452,7 +451,7 @@ class AppModelConfig(Base):
} }
) )
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"opening_statement": self.opening_statement, "opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list, "suggested_questions": self.suggested_questions_list,
@ -546,7 +545,7 @@ class RecommendedApp(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property @property
def app(self): def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@ -570,12 +569,12 @@ class InstalledApp(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property @property
def app(self): def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@property @property
def tenant(self): def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant
@ -622,7 +621,7 @@ class Conversation(Base):
mode: Mapped[str] = mapped_column(String(255)) mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(sa.Text) summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text) introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text) system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@ -652,7 +651,7 @@ class Conversation(Base):
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property @property
def inputs(self): def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy() inputs = self._inputs.copy()
# Convert file mapping to File object # Convert file mapping to File object
@ -660,22 +659,39 @@ class Conversation(Base):
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: if (
if value["transfer_method"] == FileTransferMethod.TOOL_FILE: isinstance(value, dict)
value["tool_file_id"] = value["related_id"] and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
): ):
inputs[key] = [] value_dict = cast(dict[str, Any], value)
for item in value: if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE: value_dict["tool_file_id"] = value_dict["related_id"]
item["tool_file_id"] = item["related_id"] elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: value_dict["upload_file_id"] = value_dict["related_id"]
item["upload_file_id"] = item["related_id"] tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs return inputs
@ -685,8 +701,10 @@ class Conversation(Base):
for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, File): if isinstance(v, File):
inputs[k] = v.model_dump() inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v): elif isinstance(v, list):
inputs[k] = [item.model_dump() for item in v] v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs self._inputs = inputs
@property @property
@ -826,7 +844,7 @@ class Conversation(Base):
) )
@property @property
def app(self): def app(self) -> Optional[App]:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
return session.query(App).where(App.id == self.app_id).first() return session.query(App).where(App.id == self.app_id).first()
@ -840,7 +858,7 @@ class Conversation(Base):
return None return None
@property @property
def from_account_name(self): def from_account_name(self) -> Optional[str]:
if self.from_account_id: if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account: if account:
@ -849,10 +867,10 @@ class Conversation(Base):
return None return None
@property @property
def in_debug_mode(self): def in_debug_mode(self) -> bool:
return self.override_model_configs is not None return self.override_model_configs is not None
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,
@ -898,7 +916,7 @@ class Message(Base):
model_id = mapped_column(String(255), nullable=True) model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text) override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False) query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False) message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
@ -925,28 +943,45 @@ class Message(Base):
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
@property @property
def inputs(self): def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy() inputs = self._inputs.copy()
for key, value in inputs.items(): for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: if (
if value["transfer_method"] == FileTransferMethod.TOOL_FILE: isinstance(value, dict)
value["tool_file_id"] = value["related_id"] and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
): ):
inputs[key] = [] value_dict = cast(dict[str, Any], value)
for item in value: if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE: value_dict["tool_file_id"] = value_dict["related_id"]
item["tool_file_id"] = item["related_id"] elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: value_dict["upload_file_id"] = value_dict["related_id"]
item["upload_file_id"] = item["related_id"] tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs return inputs
@inputs.setter @inputs.setter
@ -955,8 +990,10 @@ class Message(Base):
for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, File): if isinstance(v, File):
inputs[k] = v.model_dump() inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v): elif isinstance(v, list):
inputs[k] = [item.model_dump() for item in v] v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs self._inputs = inputs
@property @property
@ -1084,15 +1121,15 @@ class Message(Base):
return None return None
@property @property
def in_debug_mode(self): def in_debug_mode(self) -> bool:
return self.override_model_configs is not None return self.override_model_configs is not None
@property @property
def message_metadata_dict(self): def message_metadata_dict(self) -> dict[str, Any]:
return json.loads(self.message_metadata) if self.message_metadata else {} return json.loads(self.message_metadata) if self.message_metadata else {}
@property @property
def agent_thoughts(self): def agent_thoughts(self) -> list["MessageAgentThought"]:
return ( return (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id) .where(MessageAgentThought.message_id == self.id)
@ -1101,11 +1138,11 @@ class Message(Base):
) )
@property @property
def retriever_resources(self): def retriever_resources(self) -> Any | list[Any]:
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property @property
def message_files(self): def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory from factories import file_factory
message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
@ -1113,7 +1150,7 @@ class Message(Base):
if not current_app: if not current_app:
raise ValueError(f"App {self.app_id} not found") raise ValueError(f"App {self.app_id} not found")
files = [] files: list[File] = []
for message_file in message_files: for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.upload_file_id is None: if message_file.upload_file_id is None:
@ -1160,7 +1197,7 @@ class Message(Base):
) )
files.append(file) files.append(file)
result = [ result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files) for (file, message_file) in zip(files, message_files)
] ]
@ -1177,7 +1214,7 @@ class Message(Base):
return None return None
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,
@ -1201,7 +1238,7 @@ class Message(Base):
} }
@classmethod @classmethod
def from_dict(cls, data: dict): def from_dict(cls, data: dict[str, Any]) -> "Message":
return cls( return cls(
id=data["id"], id=data["id"],
app_id=data["app_id"], app_id=data["app_id"],
@ -1251,7 +1288,7 @@ class MessageFeedback(Base):
account = db.session.query(Account).where(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account return account
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": str(self.id), "id": str(self.id),
"app_id": str(self.app_id), "app_id": str(self.app_id),
@ -1436,7 +1473,18 @@ class EndUser(Base, UserMixin):
type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True) external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255)) name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) _is_anonymous: Mapped[bool] = mapped_column(
"is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true")
)
@property
def is_anonymous(self) -> Literal[False]:
return False
@is_anonymous.setter
def is_anonymous(self, value: bool) -> None:
self._is_anonymous = value
session_id: Mapped[str] = mapped_column() session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@ -1462,7 +1510,7 @@ class AppMCPServer(Base):
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod @staticmethod
def generate_server_code(n): def generate_server_code(n: int) -> str:
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
@ -1519,7 +1567,7 @@ class Site(Base):
self._custom_disclaimer = value self._custom_disclaimer = value
@staticmethod @staticmethod
def generate_code(n): def generate_code(n: int) -> str:
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0: while db.session.query(Site).where(Site.code == result).count() > 0:
@ -1550,7 +1598,7 @@ class ApiToken(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod @staticmethod
def generate_api_key(prefix, n): def generate_api_key(prefix: str, n: int) -> str:
while True: while True:
result = prefix + generate_string(n) result = prefix + generate_string(n)
if db.session.scalar(select(exists().where(ApiToken.token == result))): if db.session.scalar(select(exists().where(ApiToken.token == result))):
@ -1690,7 +1738,7 @@ class MessageAgentThought(Base):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property @property
def files(self): def files(self) -> list[Any]:
if self.message_files: if self.message_files:
return cast(list[Any], json.loads(self.message_files)) return cast(list[Any], json.loads(self.message_files))
else: else:
@ -1701,32 +1749,32 @@ class MessageAgentThought(Base):
return self.tool.split(";") if self.tool else [] return self.tool.split(";") if self.tool else []
@property @property
def tool_labels(self): def tool_labels(self) -> dict[str, Any]:
try: try:
if self.tool_labels_str: if self.tool_labels_str:
return cast(dict, json.loads(self.tool_labels_str)) return cast(dict[str, Any], json.loads(self.tool_labels_str))
else: else:
return {} return {}
except Exception: except Exception:
return {} return {}
@property @property
def tool_meta(self): def tool_meta(self) -> dict[str, Any]:
try: try:
if self.tool_meta_str: if self.tool_meta_str:
return cast(dict, json.loads(self.tool_meta_str)) return cast(dict[str, Any], json.loads(self.tool_meta_str))
else: else:
return {} return {}
except Exception: except Exception:
return {} return {}
@property @property
def tool_inputs_dict(self): def tool_inputs_dict(self) -> dict[str, Any]:
tools = self.tools tools = self.tools
try: try:
if self.tool_input: if self.tool_input:
data = json.loads(self.tool_input) data = json.loads(self.tool_input)
result = {} result: dict[str, Any] = {}
for tool in tools: for tool in tools:
if tool in data: if tool in data:
result[tool] = data[tool] result[tool] = data[tool]
@ -1742,12 +1790,12 @@ class MessageAgentThought(Base):
return {} return {}
@property @property
def tool_outputs_dict(self): def tool_outputs_dict(self) -> dict[str, Any]:
tools = self.tools tools = self.tools
try: try:
if self.observation: if self.observation:
data = json.loads(self.observation) data = json.loads(self.observation)
result = {} result: dict[str, Any] = {}
for tool in tools: for tool in tools:
if tool in data: if tool in data:
result[tool] = data[tool] result[tool] = data[tool]
@ -1845,14 +1893,14 @@ class TraceAppConfig(Base):
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
@property @property
def tracing_config_dict(self): def tracing_config_dict(self) -> dict[str, Any]:
return self.tracing_config or {} return self.tracing_config or {}
@property @property
def tracing_config_str(self): def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict) return json.dumps(self.tracing_config_dict)
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,

View File

@ -17,7 +17,7 @@ class ProviderType(Enum):
SYSTEM = "system" SYSTEM = "system"
@staticmethod @staticmethod
def value_of(value): def value_of(value: str) -> "ProviderType":
for member in ProviderType: for member in ProviderType:
if member.value == value: if member.value == value:
return member return member
@ -35,7 +35,7 @@ class ProviderQuotaType(Enum):
"""hosted trial quota""" """hosted trial quota"""
@staticmethod @staticmethod
def value_of(value): def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType: for member in ProviderQuotaType:
if member.value == value: if member.value == value:
return member return member

View File

@ -1,6 +1,6 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import sqlalchemy as sa import sqlalchemy as sa
@ -58,8 +58,8 @@ class ToolOAuthTenantClient(Base):
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property @property
def oauth_params(self): def oauth_params(self) -> dict[str, Any]:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base): class BuiltinToolProvider(Base):
@ -100,8 +100,8 @@ class BuiltinToolProvider(Base):
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))
@property @property
def credentials(self): def credentials(self) -> dict[str, Any]:
return cast(dict, json.loads(self.encrypted_credentials)) return cast(dict[str, Any], json.loads(self.encrypted_credentials))
class ApiToolProvider(Base): class ApiToolProvider(Base):
@ -154,8 +154,8 @@ class ApiToolProvider(Base):
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
@property @property
def credentials(self): def credentials(self) -> dict[str, Any]:
return dict(json.loads(self.credentials_str)) return dict[str, Any](json.loads(self.credentials_str))
@property @property
def user(self) -> Account | None: def user(self) -> Account | None:
@ -299,9 +299,9 @@ class MCPToolProvider(Base):
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property @property
def credentials(self): def credentials(self) -> dict[str, Any]:
try: try:
return cast(dict, json.loads(self.encrypted_credentials)) or {} return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except Exception: except Exception:
return {} return {}
@ -341,12 +341,12 @@ class MCPToolProvider(Base):
return mask_url(self.decrypted_server_url) return mask_url(self.decrypted_server_url)
@property @property
def decrypted_credentials(self): def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self) provider_controller = MCPToolProviderController.from_db(self)
encrypter, _ = create_provider_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -354,7 +354,7 @@ class MCPToolProvider(Base):
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
return encrypter.decrypt(self.credentials) # type: ignore return encrypter.decrypt(self.credentials)
class ToolModelInvoke(Base): class ToolModelInvoke(Base):

View File

@ -1,29 +1,34 @@
import enum import enum
from typing import Generic, TypeVar import uuid
from typing import Any, Generic, TypeVar
from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine
class StringUUID(TypeDecorator): class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR impl = CHAR
cache_ok = True cache_ok = True
def process_bind_param(self, value, dialect): def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
elif dialect.name == "postgresql": elif dialect.name == "postgresql":
return str(value) return str(value)
else: else:
return value.hex if isinstance(value, uuid.UUID):
return value.hex
return value
def load_dialect_impl(self, dialect): def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql": if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) return dialect.type_descriptor(UUID())
else: else:
return dialect.type_descriptor(CHAR(36)) return dialect.type_descriptor(CHAR(36))
def process_result_value(self, value, dialect): def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
return str(value) return str(value)
@ -32,7 +37,7 @@ class StringUUID(TypeDecorator):
_E = TypeVar("_E", bound=enum.StrEnum) _E = TypeVar("_E", bound=enum.StrEnum)
class EnumText(TypeDecorator, Generic[_E]): class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR impl = VARCHAR
cache_ok = True cache_ok = True
@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]):
# leave some rooms for future longer enum values. # leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20) self._length = max(max_enum_value_len, 20)
def process_bind_param(self, value: _E | str | None, dialect): def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
if isinstance(value, self._enum_class): if isinstance(value, self._enum_class):
return value.value return value.value
elif isinstance(value, str): # Since _E is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value) self._enum_class(value)
return value return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
def load_dialect_impl(self, dialect): def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length)) return dialect.type_descriptor(VARCHAR(self._length))
def process_result_value(self, value, dialect) -> _E | None: def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None: if value is None:
return value return value
if not isinstance(value, str): # Type annotation guarantees value is str at this point
raise TypeError(f"expected str, got {type(value)}")
return self._enum_class(value) return self._enum_class(value)
def compare_values(self, x, y): def compare_values(self, x: _E | None, y: _E | None) -> bool:
if x is None or y is None: if x is None or y is None:
return x is y return x is y
return x == y return x == y

View File

@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
@ -224,7 +224,7 @@ class Workflow(Base):
raise WorkflowDataError("nodes not found in workflow graph") raise WorkflowDataError("nodes not found in workflow graph")
try: try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes)) node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration: except StopIteration:
raise NodeNotFoundError(node_id) raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict) assert isinstance(node_config, dict)
@ -289,7 +289,7 @@ class Workflow(Base):
def features_dict(self) -> dict[str, Any]: def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {} return json.loads(self.features) if self.features else {}
def user_input_form(self, to_old_structure: bool = False): def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
# get start node from graph # get start node from graph
if not self.graph: if not self.graph:
return [] return []
@ -306,7 +306,7 @@ class Workflow(Base):
variables: list[Any] = start_node.get("data", {}).get("variables", []) variables: list[Any] = start_node.get("data", {}).get("variables", [])
if to_old_structure: if to_old_structure:
old_structure_variables = [] old_structure_variables: list[dict[str, Any]] = []
for variable in variables: for variable in variables:
old_structure_variables.append({variable["type"]: variable}) old_structure_variables.append({variable["type"]: variable})
@ -346,9 +346,7 @@ class Workflow(Base):
@property @property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created. # _environment_variables is guaranteed to be non-None due to server_default="{}"
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads # Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id tenant_id = self.tenant_id
@ -362,17 +360,18 @@ class Workflow(Base):
] ]
# decrypt secret variables value # decrypt secret variables value
def decrypt_func(var): def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var return var
else: else:
raise AssertionError("this statement should be unreachable.") # Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
map(decrypt_func, results) decrypt_func(var) for var in results
) ]
return decrypted_results return decrypted_results
@environment_variables.setter @environment_variables.setter
@ -400,7 +399,7 @@ class Workflow(Base):
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value # encrypt secret variables value
def encrypt_func(var): def encrypt_func(var: Variable) -> Variable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else: else:
@ -430,9 +429,7 @@ class Workflow(Base):
@property @property
def conversation_variables(self) -> Sequence[Variable]: def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created. # _conversation_variables is guaranteed to be non-None due to server_default="{}"
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables) variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
@ -577,7 +574,7 @@ class WorkflowRun(Base):
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> "WorkflowRun": def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base):
__tablename__ = "workflow_node_executions" __tablename__ = "workflow_node_executions"
@declared_attr @declared_attr
def __table_args__(cls): # noqa @classmethod
def __table_args__(cls) -> Any:
return ( return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index( Index(
@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base):
# MyPy may flag the following line because it doesn't recognize that # MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first # the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes. # argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore cls.created_at.desc(),
), ),
) )
@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base):
return json.loads(self.execution_metadata) if self.execution_metadata else {} return json.loads(self.execution_metadata) if self.execution_metadata else {}
@property @property
def extras(self): def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
extras = {} extras: dict[str, Any] = {}
if self.execution_metadata_dict: if self.execution_metadata_dict:
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"] tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon( extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"], provider_type=tool_info["provider_type"],
@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base):
# making this attribute harder to access from outside the class. # making this attribute harder to access from outside the class.
__value: Segment | None __value: Segment | None
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
""" """
The constructor of `WorkflowDraftVariable` is not intended for The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state direct use outside this file. Its solo purpose is setup private state
@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base):
self.__value = None self.__value = None
def get_selector(self) -> list[str]: def get_selector(self) -> list[str]:
selector = json.loads(self.selector) selector: Any = json.loads(self.selector)
if not isinstance(selector, list): if not isinstance(selector, list):
logger.error( logger.error(
"invalid selector loaded from database, type=%s, value=%s", "invalid selector loaded from database, type=%s, value=%s",
type(selector), type(selector).__name__,
self.selector, self.selector,
) )
raise ValueError("invalid selector.") raise ValueError("invalid selector.")
return selector return cast(list[str], selector)
def _set_selector(self, value: list[str]): def _set_selector(self, value: list[str]):
self.selector = json.dumps(value) self.selector = json.dumps(value)
@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base):
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict): if isinstance(value, dict):
if not maybe_file_object(value): if not maybe_file_object(value):
return value return cast(Any, value)
return File.model_validate(value) return File.model_validate(value)
elif isinstance(value, list) and value: elif isinstance(value, list) and value:
first = value[0] value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first): if not maybe_file_object(first):
return value return cast(Any, value)
return [File.model_validate(i) for i in value] file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
return cast(Any, file_list)
else: else:
return value return cast(Any, value)
@classmethod @classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:

View File

@ -6,7 +6,6 @@
"tests/", "tests/",
"migrations/", "migrations/",
".venv/", ".venv/",
"models/",
"core/", "core/",
"controllers/", "controllers/",
"tasks/", "tasks/",

View File

@ -1,8 +1,7 @@
import threading import threading
from typing import Optional from typing import Any, Optional
import pytz import pytz
from flask_login import current_user
import contexts import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@ -10,6 +9,7 @@ from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_user
from models.account import Account from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought from models.model import App, Conversation, EndUser, Message, MessageAgentThought
@ -61,14 +61,15 @@ class AgentService:
executor = executor.name executor = executor.name
else: else:
executor = "Unknown" executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone) timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
if not app_model_config: if not app_model_config:
raise ValueError("App model config not found") raise ValueError("App model config not found")
result = { result: dict[str, Any] = {
"meta": { "meta": {
"status": "success", "status": "success",
"executor": executor, "executor": executor,

View File

@ -2,7 +2,6 @@ import uuid
from typing import Optional from typing import Optional
import pandas as pd import pandas as pd
from flask_login import current_user
from sqlalchemy import or_, select from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,6 +25,7 @@ class AppAnnotationService:
@classmethod @classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info # get app info
assert isinstance(current_user, Account)
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -62,6 +64,7 @@ class AppAnnotationService:
db.session.commit() db.session.commit()
# if annotation reply is enabled , add annotation to index # if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
if annotation_setting: if annotation_setting:
add_annotation_to_index_task.delay( add_annotation_to_index_task.delay(
annotation.id, annotation.id,
@ -84,6 +87,8 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task # send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting") redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay( enable_annotation_reply_task.delay(
str(job_id), str(job_id),
app_id, app_id,
@ -97,6 +102,8 @@ class AppAnnotationService:
@classmethod @classmethod
def disable_app_annotation(cls, app_id: str): def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key) cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None: if cache_result is not None:
@ -113,6 +120,8 @@ class AppAnnotationService:
@classmethod @classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -145,6 +154,8 @@ class AppAnnotationService:
@classmethod @classmethod
def export_annotation_list_by_app_id(cls, app_id: str): def export_annotation_list_by_app_id(cls, app_id: str):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -164,6 +175,8 @@ class AppAnnotationService:
@classmethod @classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -193,6 +206,8 @@ class AppAnnotationService:
@classmethod @classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -230,6 +245,8 @@ class AppAnnotationService:
@classmethod @classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str): def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -269,6 +286,8 @@ class AppAnnotationService:
@classmethod @classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -317,6 +336,8 @@ class AppAnnotationService:
@classmethod @classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage): def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info # get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -355,6 +376,8 @@ class AppAnnotationService:
@classmethod @classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
@ -425,6 +448,8 @@ class AppAnnotationService:
@classmethod @classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str): def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
@ -451,6 +476,8 @@ class AppAnnotationService:
@classmethod @classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info # get app info
app = ( app = (
db.session.query(App) db.session.query(App)
@ -491,6 +518,8 @@ class AppAnnotationService:
@classmethod @classmethod
def clear_all_annotations(cls, app_id: str): def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = ( app = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@ -2,7 +2,6 @@ import json
import logging import logging
from typing import Optional, TypedDict, cast from typing import Optional, TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination from flask_sqlalchemy.pagination import Pagination
from configs import dify_config from configs import dify_config
@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created from events.app_event import app_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
@ -168,9 +168,13 @@ class AppService:
""" """
Get App Get App
""" """
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config # get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config = app.app_model_config model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input # decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get("tools") or []:
@ -205,7 +209,8 @@ class AppService:
pass pass
# override agent mode # override agent mode
model_config.agent_mode = json.dumps(agent_mode) if model_config:
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App): class ModifiedApp(App):
""" """
@ -239,6 +244,7 @@ class AppService:
:param args: request args :param args: request args
:return: App instance :return: App instance
""" """
assert current_user is not None
app.name = args["name"] app.name = args["name"]
app.description = args["description"] app.description = args["description"]
app.icon_type = args["icon_type"] app.icon_type = args["icon_type"]
@ -259,6 +265,7 @@ class AppService:
:param name: new name :param name: new name
:return: App instance :return: App instance
""" """
assert current_user is not None
app.name = name app.name = name
app.updated_by = current_user.id app.updated_by = current_user.id
app.updated_at = naive_utc_now() app.updated_at = naive_utc_now()
@ -274,6 +281,7 @@ class AppService:
:param icon_background: new icon_background :param icon_background: new icon_background
:return: App instance :return: App instance
""" """
assert current_user is not None
app.icon = icon app.icon = icon
app.icon_background = icon_background app.icon_background = icon_background
app.updated_by = current_user.id app.updated_by = current_user.id
@ -291,7 +299,7 @@ class AppService:
""" """
if enable_site == app.enable_site: if enable_site == app.enable_site:
return app return app
assert current_user is not None
app.enable_site = enable_site app.enable_site = enable_site
app.updated_by = current_user.id app.updated_by = current_user.id
app.updated_at = naive_utc_now() app.updated_at = naive_utc_now()
@ -308,6 +316,7 @@ class AppService:
""" """
if enable_api == app.enable_api: if enable_api == app.enable_api:
return app return app
assert current_user is not None
app.enable_api = enable_api app.enable_api = enable_api
app.updated_by = current_user.id app.updated_by = current_user.id

View File

@ -12,7 +12,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import MessageStatus from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message from models.model import App, AppMode, Message
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
NoAudioUploadedServiceError, NoAudioUploadedServiceError,
@ -40,7 +40,9 @@ class AudioService:
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled") raise ValueError("Speech to text is not enabled")
else: else:
app_model_config: AppModelConfig = app_model.app_model_config app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")
if not app_model_config.speech_to_text_dict["enabled"]: if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled") raise ValueError("Speech to text is not enabled")

View File

@ -70,7 +70,7 @@ class BillingService:
return response.json() return response.json()
@staticmethod @staticmethod
def is_tenant_owner_or_admin(current_user): def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = ( join: Optional[TenantAccountJoin] = (

View File

@ -8,7 +8,7 @@ import uuid
from collections import Counter from collections import Counter
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from flask_login import current_user import sqlalchemy as sa
from sqlalchemy import exists, func, select from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -26,6 +26,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole from models.account import Account, TenantAccountRole
from models.dataset import ( from models.dataset import (
AppDatasetJoin, AppDatasetJoin,
@ -498,8 +499,11 @@ class DatasetService:
data: Update data dictionary data: Update data dictionary
filtered_data: Filtered update data to modify filtered_data: Filtered update data to modify
""" """
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try: try:
model_manager = ModelManager() model_manager = ModelManager()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"], provider=data["embedding_model_provider"],
@ -611,8 +615,12 @@ class DatasetService:
data: Update data dictionary data: Update data dictionary
filtered_data: Filtered update data to modify filtered_data: Filtered update data to modify
""" """
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager() model_manager = ModelManager()
try: try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance( embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"], provider=data["embedding_model_provider"],
@ -720,6 +728,8 @@ class DatasetService:
@staticmethod @staticmethod
def get_dataset_auto_disable_logs(dataset_id: str): def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox": if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return { return {
@ -924,6 +934,8 @@ class DocumentService:
@staticmethod @staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
assert isinstance(current_user, Account)
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.where( .where(
@ -973,7 +985,7 @@ class DocumentService:
file_ids = [ file_ids = [
document.data_source_info_dict["upload_file_id"] document.data_source_info_dict["upload_file_id"]
for document in documents for document in documents
if document.data_source_type == "upload_file" if document.data_source_type == "upload_file" and document.data_source_info_dict
] ]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
@ -983,6 +995,8 @@ class DocumentService:
@staticmethod @staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document: def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise ValueError("Dataset not found.") raise ValueError("Dataset not found.")
@ -1012,6 +1026,7 @@ class DocumentService:
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
raise DocumentIndexingError() raise DocumentIndexingError()
# update document to be paused # update document to be paused
assert current_user is not None
document.is_paused = True document.is_paused = True
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = naive_utc_now() document.paused_at = naive_utc_now()
@ -1067,8 +1082,9 @@ class DocumentService:
# sync document indexing # sync document indexing
document.indexing_status = "waiting" document.indexing_status = "waiting"
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
data_source_info["mode"] = "scrape" if data_source_info:
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
@ -1097,6 +1113,9 @@ class DocumentService:
# check doc_form # check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form) DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit # check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
@ -1433,6 +1452,8 @@ class DocumentService:
@staticmethod @staticmethod
def get_tenant_documents_count(): def get_tenant_documents_count():
assert isinstance(current_user, Account)
documents_count = ( documents_count = (
db.session.query(Document) db.session.query(Document)
.where( .where(
@ -1453,6 +1474,8 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web", created_from: str = "web",
): ):
assert isinstance(current_user, Account)
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data.original_document_id) document = DocumentService.get_document(dataset.id, document_data.original_document_id)
if document is None: if document is None:
@ -1512,7 +1535,7 @@ class DocumentService:
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.where( .where(
db.and_( sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
@ -1573,6 +1596,9 @@ class DocumentService:
@staticmethod @staticmethod
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
@ -2012,6 +2038,9 @@ class SegmentService:
@classmethod @classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset): def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
content = args["content"] content = args["content"]
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content) segment_hash = helper.generate_text_hash(content)
@ -2074,6 +2103,9 @@ class SegmentService:
@classmethod @classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
lock_name = f"multi_add_segment_lock_document_id_{document.id}" lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0 increment_word_count = 0
with redis_client.lock(lock_name, timeout=600): with redis_client.lock(lock_name, timeout=600):
@ -2157,6 +2189,9 @@ class SegmentService:
@classmethod @classmethod
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
indexing_cache_key = f"segment_{segment.id}_indexing" indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
@ -2348,6 +2383,7 @@ class SegmentService:
@classmethod @classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = ( segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.where( .where(
@ -2378,6 +2414,8 @@ class SegmentService:
def update_segments_status( def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
): ):
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition # Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0: if not segment_ids or len(segment_ids) == 0:
return return
@ -2440,6 +2478,8 @@ class SegmentService:
def create_child_chunk( def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk: ) -> ChildChunk:
assert isinstance(current_user, Account)
lock_name = f"add_child_lock_{segment.id}" lock_name = f"add_child_lock_{segment.id}"
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4()) index_node_id = str(uuid.uuid4())
@ -2487,6 +2527,8 @@ class SegmentService:
document: Document, document: Document,
dataset: Dataset, dataset: Dataset,
) -> list[ChildChunk]: ) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .where(
@ -2561,6 +2603,8 @@ class SegmentService:
document: Document, document: Document,
dataset: Dataset, dataset: Dataset,
) -> ChildChunk: ) -> ChildChunk:
assert current_user is not None
try: try:
child_chunk.content = content child_chunk.content = content
child_chunk.word_count = len(content) child_chunk.word_count = len(content)
@ -2591,6 +2635,8 @@ class SegmentService:
def get_child_chunks( def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
): ):
assert isinstance(current_user, Account)
query = ( query = (
select(ChildChunk) select(ChildChunk)
.filter_by( .filter_by(

View File

@ -114,8 +114,9 @@ class ExternalDatasetService:
) )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: settings = args.get("settings")
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name") external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "") external_knowledge_api.description = args.get("description", "")

View File

@ -3,7 +3,6 @@ import os
import uuid import uuid
from typing import Any, Literal, Union from typing import Any, Literal, Union
from flask_login import current_user
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
@ -19,6 +18,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account from models.account import Account
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile from models.model import EndUser, UploadFile
@ -111,6 +111,9 @@ class FileService:
@staticmethod @staticmethod
def upload_text(text: str, text_name: str) -> UploadFile: def upload_text(text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if len(text_name) > 200: if len(text_name) > 200:
text_name = text_name[:200] text_name = text_name[:200]
# user uuid as file name # user uuid as file name

View File

@ -226,7 +226,7 @@ class MCPToolManageService:
def update_mcp_provider_credentials( def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
): ):
provider_controller = MCPToolProviderController._from_db(mcp_provider) provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id, tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]

View File

@ -1,10 +1,11 @@
import json import json
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from models.account import Account
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.agent_service import AgentService from services.agent_service import AgentService
@ -21,7 +22,7 @@ class TestAgentService:
patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client,
patch("services.agent_service.ToolManager") as mock_tool_manager, patch("services.agent_service.ToolManager") as mock_tool_manager,
patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager,
patch("services.agent_service.current_user") as mock_current_user, patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager, patch("services.app_service.ModelManager") as mock_model_manager,

View File

@ -1,9 +1,10 @@
from unittest.mock import patch from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from models.account import Account
from models.model import MessageAnnotation from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.app_service import AppService from services.app_service import AppService
@ -24,7 +25,9 @@ class TestAnnotationService:
patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task,
patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task,
patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task,
patch("services.annotation_service.current_user") as mock_current_user, patch(
"services.annotation_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
): ):
# Setup default mock returns # Setup default mock returns
mock_account_feature_service.get_features.return_value.billing.enabled = False mock_account_feature_service.get_features.return_value.billing.enabled = False

View File

@ -1,9 +1,10 @@
from unittest.mock import patch from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
from constants.model_template import default_app_templates from constants.model_template import default_app_templates
from models.account import Account
from models.model import App, Site from models.model import App, Site
from services.account_service import AccountService, TenantService from services.account_service import AccountService, TenantService
from services.app_service import AppService from services.app_service import AppService
@ -161,8 +162,13 @@ class TestAppService:
app_service = AppService() app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account) created_app = app_service.create_app(tenant.id, app_args, account)
# Get app using the service # Get app using the service - needs current_user mock
retrieved_app = app_service.get_app(created_app) mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
retrieved_app = app_service.get_app(created_app)
# Verify retrieved app matches created app # Verify retrieved app matches created app
assert retrieved_app.id == created_app.id assert retrieved_app.id == created_app.id
@ -406,7 +412,11 @@ class TestAppService:
"use_icon_as_answer_icon": True, "use_icon_as_answer_icon": True,
} }
with patch("flask_login.utils._get_user", return_value=account): mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app(app, update_args) updated_app = app_service.update_app(app, update_args)
# Verify updated fields # Verify updated fields
@ -456,7 +466,11 @@ class TestAppService:
# Update app name # Update app name
new_name = "New App Name" new_name = "New App Name"
with patch("flask_login.utils._get_user", return_value=account): mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_name(app, new_name) updated_app = app_service.update_app_name(app, new_name)
assert updated_app.name == new_name assert updated_app.name == new_name
@ -504,7 +518,11 @@ class TestAppService:
# Update app icon # Update app icon
new_icon = "🌟" new_icon = "🌟"
new_icon_background = "#FFD93D" new_icon_background = "#FFD93D"
with patch("flask_login.utils._get_user", return_value=account): mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
assert updated_app.icon == new_icon assert updated_app.icon == new_icon
@ -551,13 +569,17 @@ class TestAppService:
original_site_status = app.enable_site original_site_status = app.enable_site
# Update site status to disabled # Update site status to disabled
with patch("flask_login.utils._get_user", return_value=account): mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(app, False) updated_app = app_service.update_app_site_status(app, False)
assert updated_app.enable_site is False assert updated_app.enable_site is False
assert updated_app.updated_by == account.id assert updated_app.updated_by == account.id
# Update site status back to enabled # Update site status back to enabled
with patch("flask_login.utils._get_user", return_value=account): with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_site_status(updated_app, True) updated_app = app_service.update_app_site_status(updated_app, True)
assert updated_app.enable_site is True assert updated_app.enable_site is True
assert updated_app.updated_by == account.id assert updated_app.updated_by == account.id
@ -602,13 +624,17 @@ class TestAppService:
original_api_status = app.enable_api original_api_status = app.enable_api
# Update API status to disabled # Update API status to disabled
with patch("flask_login.utils._get_user", return_value=account): mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(app, False) updated_app = app_service.update_app_api_status(app, False)
assert updated_app.enable_api is False assert updated_app.enable_api is False
assert updated_app.updated_by == account.id assert updated_app.updated_by == account.id
# Update API status back to enabled # Update API status back to enabled
with patch("flask_login.utils._get_user", return_value=account): with patch("services.app_service.current_user", mock_current_user):
updated_app = app_service.update_app_api_status(updated_app, True) updated_app = app_service.update_app_api_status(updated_app, True)
assert updated_app.enable_api is True assert updated_app.enable_api is True
assert updated_app.updated_by == account.id assert updated_app.updated_by == account.id

View File

@ -1,6 +1,6 @@
import hashlib import hashlib
from io import BytesIO from io import BytesIO
from unittest.mock import patch from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
@ -417,11 +417,12 @@ class TestFileService:
text = "This is a test text content" text = "This is a test text content"
text_name = "test_text.txt" text_name = "test_text.txt"
# Mock current_user # Mock current_user using create_autospec
with patch("services.file_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4()) mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4()) mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name) upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None assert upload_file is not None
@ -443,11 +444,12 @@ class TestFileService:
text = "test content" text = "test content"
long_name = "a" * 250 # Longer than 200 characters long_name = "a" * 250 # Longer than 200 characters
# Mock current_user # Mock current_user using create_autospec
with patch("services.file_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4()) mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4()) mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=long_name) upload_file = FileService.upload_text(text=text, text_name=long_name)
# Verify name was truncated # Verify name was truncated
@ -846,11 +848,12 @@ class TestFileService:
text = "" text = ""
text_name = "empty.txt" text_name = "empty.txt"
# Mock current_user # Mock current_user using create_autospec
with patch("services.file_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = str(fake.uuid4()) mock_current_user.current_tenant_id = str(fake.uuid4())
mock_current_user.id = str(fake.uuid4()) mock_current_user.id = str(fake.uuid4())
with patch("services.file_service.current_user", mock_current_user):
upload_file = FileService.upload_text(text=text, text_name=text_name) upload_file = FileService.upload_text(text=text, text_name=text_name)
assert upload_file is not None assert upload_file is not None

View File

@ -1,4 +1,4 @@
from unittest.mock import patch from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
@ -17,7 +17,9 @@ class TestMetadataService:
def mock_external_service_dependencies(self): def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies.""" """Mock setup for external service dependencies."""
with ( with (
patch("services.metadata_service.current_user") as mock_current_user, patch(
"services.metadata_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
patch("services.metadata_service.redis_client") as mock_redis_client, patch("services.metadata_service.redis_client") as mock_redis_client,
patch("services.dataset_service.DocumentService") as mock_document_service, patch("services.dataset_service.DocumentService") as mock_document_service,
): ):

View File

@ -1,4 +1,4 @@
from unittest.mock import patch from unittest.mock import create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
@ -17,7 +17,7 @@ class TestTagService:
def mock_external_service_dependencies(self): def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies.""" """Mock setup for external service dependencies."""
with ( with (
patch("services.tag_service.current_user") as mock_current_user, patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
): ):
# Setup default mock returns # Setup default mock returns
mock_current_user.current_tenant_id = "test-tenant-id" mock_current_user.current_tenant_id = "test-tenant-id"

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, create_autospec, patch
import pytest import pytest
from faker import Faker from faker import Faker
@ -231,9 +231,10 @@ class TestWebsiteService:
fake = Faker() fake = Faker()
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="firecrawl", provider="firecrawl",
@ -285,9 +286,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="watercrawl", provider="watercrawl",
@ -336,9 +338,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for single page crawling # Create API request for single page crawling
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="jinareader", provider="jinareader",
@ -389,9 +392,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider # Create API request with invalid provider
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="invalid_provider", provider="invalid_provider",
@ -419,9 +423,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123")
@ -463,9 +468,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123")
@ -502,9 +508,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123")
@ -544,9 +551,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request with invalid provider # Create API request with invalid provider
api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123")
@ -569,9 +577,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing credentials # Mock missing credentials
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None
@ -597,9 +606,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Mock missing API key in config # Mock missing API key in config
mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = {
"config": {"base_url": "https://api.example.com"} "config": {"base_url": "https://api.example.com"}
@ -995,9 +1005,10 @@ class TestWebsiteService:
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request for sub-page crawling # Create API request for sub-page crawling
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="jinareader", provider="jinareader",
@ -1054,9 +1065,10 @@ class TestWebsiteService:
mock_external_service_dependencies["requests"].get.return_value = mock_failed_response mock_external_service_dependencies["requests"].get.return_value = mock_failed_response
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlApiRequest( api_request = WebsiteCrawlApiRequest(
provider="jinareader", provider="jinareader",
@ -1096,9 +1108,10 @@ class TestWebsiteService:
mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance
# Mock current_user for the test # Mock current_user for the test
with patch("services.website_service.current_user") as mock_current_user: mock_current_user = create_autospec(Account, instance=True)
mock_current_user.current_tenant_id = account.current_tenant.id mock_current_user.current_tenant_id = account.current_tenant.id
with patch("services.website_service.current_user", mock_current_user):
# Create API request # Create API request
api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123")

File diff suppressed because it is too large Load Diff

View File

@ -154,7 +154,7 @@ class TestEnumText:
TestCase( TestCase(
name="session insert with invalid type", name="session insert with invalid type",
action=lambda s: _session_insert_with_value(s, 1), action=lambda s: _session_insert_with_value(s, 1),
exc_type=TypeError, exc_type=ValueError,
), ),
TestCase( TestCase(
name="insert with invalid value", name="insert with invalid value",
@ -164,7 +164,7 @@ class TestEnumText:
TestCase( TestCase(
name="insert with invalid type", name="insert with invalid type",
action=lambda s: _insert_with_user(s, 1), action=lambda s: _insert_with_user(s, 1),
exc_type=TypeError, exc_type=ValueError,
), ),
] ]
for idx, c in enumerate(cases, 1): for idx, c in enumerate(cases, 1):

View File

@ -2,11 +2,12 @@ import datetime
from typing import Any, Optional from typing import Any, Optional
# Mock redis_client before importing dataset_service # Mock redis_client before importing dataset_service
from unittest.mock import Mock, patch from unittest.mock import Mock, create_autospec, patch
import pytest import pytest
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError from services.errors.account import NoPermissionError
@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory:
@staticmethod @staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user.""" """Create a mock current user."""
current_user = Mock() current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id current_user.current_tenant_id = tenant_id
return current_user return current_user
@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset:
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding, ) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch("services.dataset_service.current_user") as mock_current_user, patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
): ):
mock_current_user.current_tenant_id = "tenant-123" mock_current_user.current_tenant_id = "tenant-123"
yield { yield {

View File

@ -1,9 +1,10 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, create_autospec, patch
import pytest import pytest
from flask_restx import reqparse from flask_restx import reqparse
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation:
mock_metadata_args.name = None mock_metadata_args.name = None
mock_metadata_args.type = "string" mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user: mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123" mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456" mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Should crash with TypeError # Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args) MetadataService.create_metadata("dataset-123", mock_metadata_args)
# Test update method as well # Test update method as well
with patch("services.metadata_service.current_user") as mock_user: mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123" mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456" mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None) MetadataService.update_metadata_name("dataset-123", "metadata-456", None)

View File

@ -1,8 +1,9 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, create_autospec, patch
import pytest import pytest
from flask_restx import reqparse from flask_restx import reqparse
from models.account import Account
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
@ -24,20 +25,22 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # This will cause len() to crash mock_metadata_args.name = None # This will cause len() to crash
mock_metadata_args.type = "string" mock_metadata_args.type = "string"
with patch("services.metadata_service.current_user") as mock_user: mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123" mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456" mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None) # This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args) MetadataService.create_metadata("dataset-123", mock_metadata_args)
def test_metadata_service_update_with_none_name_crashes(self): def test_metadata_service_update_with_none_name_crashes(self):
"""Test that MetadataService.update_metadata_name crashes when name is None.""" """Test that MetadataService.update_metadata_name crashes when name is None."""
with patch("services.metadata_service.current_user") as mock_user: mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123" mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456" mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# This should crash with TypeError when calling len(None) # This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None) MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
@ -81,10 +84,11 @@ class TestMetadataNullableBug:
mock_metadata_args.name = None # From args["name"] mock_metadata_args.name = None # From args["name"]
mock_metadata_args.type = None # From args["type"] mock_metadata_args.type = None # From args["type"]
with patch("services.metadata_service.current_user") as mock_user: mock_user = create_autospec(Account, instance=True)
mock_user.current_tenant_id = "tenant-123" mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456" mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
# Step 4: Service layer crashes on len(None) # Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args) MetadataService.create_metadata("dataset-123", mock_metadata_args)

View File

@ -72,6 +72,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false) const [showSwitchModal, setShowSwitchModal] = useState<boolean>(false)
const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false) const [showImportDSLModal, setShowImportDSLModal] = useState<boolean>(false)
const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([]) const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([])
const [showExportWarning, setShowExportWarning] = useState(false)
const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({
name, name,
@ -159,6 +160,14 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
onExport() onExport()
return return
} }
setShowExportWarning(true)
}
const handleConfirmExport = async () => {
if (!appDetail)
return
setShowExportWarning(false)
try { try {
const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`)
const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret')
@ -407,6 +416,16 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
onClose={() => setSecretEnvList([])} onClose={() => setSecretEnvList([])}
/> />
)} )}
{showExportWarning && (
<Confirm
type="info"
isShow={showExportWarning}
title={t('workflow.sidebar.exportWarning')}
content={t('workflow.sidebar.exportWarningDesc')}
onConfirm={handleConfirmExport}
onCancel={() => setShowExportWarning(false)}
/>
)}
</div> </div>
) )
} }

View File

@ -32,6 +32,7 @@ export type ActionButtonProps = {
size?: 'xs' | 's' | 'm' | 'l' | 'xl' size?: 'xs' | 's' | 'm' | 'l' | 'xl'
state?: ActionButtonState state?: ActionButtonState
styleCss?: CSSProperties styleCss?: CSSProperties
ref?: React.Ref<HTMLButtonElement>
} & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof actionButtonVariants> } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof actionButtonVariants>
function getActionButtonState(state: ActionButtonState) { function getActionButtonState(state: ActionButtonState) {
@ -49,24 +50,22 @@ function getActionButtonState(state: ActionButtonState) {
} }
} }
const ActionButton = React.forwardRef<HTMLButtonElement, ActionButtonProps>( const ActionButton = ({ className, size, state = ActionButtonState.Default, styleCss, children, ref, ...props }: ActionButtonProps) => {
({ className, size, state = ActionButtonState.Default, styleCss, children, ...props }, ref) => { return (
return ( <button
<button type='button'
type='button' className={classNames(
className={classNames( actionButtonVariants({ className, size }),
actionButtonVariants({ className, size }), getActionButtonState(state),
getActionButtonState(state), )}
)} ref={ref}
ref={ref} style={styleCss}
style={styleCss} {...props}
{...props} >
> {children}
{children} </button>
</button> )
) }
},
)
ActionButton.displayName = 'ActionButton' ActionButton.displayName = 'ActionButton'
export default ActionButton export default ActionButton

View File

@ -35,27 +35,26 @@ export type ButtonProps = {
loading?: boolean loading?: boolean
styleCss?: CSSProperties styleCss?: CSSProperties
spinnerClassName?: string spinnerClassName?: string
ref?: React.Ref<HTMLButtonElement>
} & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants> } & React.ButtonHTMLAttributes<HTMLButtonElement> & VariantProps<typeof buttonVariants>
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>( const Button = ({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ref, ...props }: ButtonProps) => {
({ className, variant, size, destructive, loading, styleCss, children, spinnerClassName, ...props }, ref) => { return (
return ( <button
<button type='button'
type='button' className={classNames(
className={classNames( buttonVariants({ variant, size, className }),
buttonVariants({ variant, size, className }), destructive && 'btn-destructive',
destructive && 'btn-destructive', )}
)} ref={ref}
ref={ref} style={styleCss}
style={styleCss} {...props}
{...props} >
> {children}
{children} {loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />}
{loading && <Spinner loading={loading} className={classNames('!ml-1 !h-3 !w-3 !border-2 !text-white', spinnerClassName)} />} </button>
</button> )
) }
},
)
Button.displayName = 'Button' Button.displayName = 'Button'
export default Button export default Button

View File

@ -30,9 +30,10 @@ export type InputProps = {
wrapperClassName?: string wrapperClassName?: string
styleCss?: CSSProperties styleCss?: CSSProperties
unit?: string unit?: string
ref?: React.Ref<HTMLInputElement>
} & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants> } & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants>
const Input = React.forwardRef<HTMLInputElement, InputProps>(({ const Input = ({
size, size,
disabled, disabled,
destructive, destructive,
@ -46,8 +47,9 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
placeholder, placeholder,
onChange = noop, onChange = noop,
unit, unit,
ref,
...props ...props
}, ref) => { }: InputProps) => {
const { t } = useTranslation() const { t } = useTranslation()
return ( return (
<div className={cn('relative w-full', wrapperClassName)}> <div className={cn('relative w-full', wrapperClassName)}>
@ -93,7 +95,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
} }
</div> </div>
) )
}) }
Input.displayName = 'Input' Input.displayName = 'Input'

View File

@ -107,10 +107,13 @@ const initMermaid = () => {
return isMermaidInitialized return isMermaidInitialized
} }
const Flowchart = React.forwardRef((props: { type FlowchartProps = {
PrimitiveCode: string PrimitiveCode: string
theme?: 'light' | 'dark' theme?: 'light' | 'dark'
}, ref) => { ref?: React.Ref<HTMLDivElement>
}
const Flowchart = (props: FlowchartProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const [svgString, setSvgString] = useState<string | null>(null) const [svgString, setSvgString] = useState<string | null>(null)
const [look, setLook] = useState<'classic' | 'handDrawn'>('classic') const [look, setLook] = useState<'classic' | 'handDrawn'>('classic')
@ -490,7 +493,7 @@ const Flowchart = React.forwardRef((props: {
} }
return ( return (
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> <div ref={props.ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
<div className={themeClasses.segmented}> <div className={themeClasses.segmented}>
<div className="msh-segmented-group"> <div className="msh-segmented-group">
<label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1"> <label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1">
@ -572,7 +575,7 @@ const Flowchart = React.forwardRef((props: {
)} )}
</div> </div>
) )
}) }
Flowchart.displayName = 'Flowchart' Flowchart.displayName = 'Flowchart'

View File

@ -24,30 +24,29 @@ export type TextareaProps = {
disabled?: boolean disabled?: boolean
destructive?: boolean destructive?: boolean
styleCss?: CSSProperties styleCss?: CSSProperties
ref?: React.Ref<HTMLTextAreaElement>
} & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants> } & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants>
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>( const Textarea = ({ className, value, onChange, disabled, size, destructive, styleCss, ref, ...props }: TextareaProps) => {
({ className, value, onChange, disabled, size, destructive, styleCss, ...props }, ref) => { return (
return ( <textarea
<textarea ref={ref}
ref={ref} style={styleCss}
style={styleCss} className={cn(
className={cn( 'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs',
'min-h-20 w-full appearance-none border border-transparent bg-components-input-bg-normal p-2 text-components-input-text-filled caret-primary-600 outline-none placeholder:text-components-input-text-placeholder hover:border-components-input-border-hover hover:bg-components-input-bg-hover focus:border-components-input-border-active focus:bg-components-input-bg-active focus:shadow-xs', textareaVariants({ size }),
textareaVariants({ size }), disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled',
disabled && 'cursor-not-allowed border-transparent bg-components-input-bg-disabled text-components-input-text-filled-disabled hover:border-transparent hover:bg-components-input-bg-disabled', destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive',
destructive && 'border-components-input-border-destructive bg-components-input-bg-destructive text-components-input-text-filled hover:border-components-input-border-destructive hover:bg-components-input-bg-destructive focus:border-components-input-border-destructive focus:bg-components-input-bg-destructive', className,
className, )}
)} value={value ?? ''}
value={value ?? ''} onChange={onChange}
onChange={onChange} disabled={disabled}
disabled={disabled} {...props}
{...props} >
> </textarea>
</textarea> )
) }
},
)
Textarea.displayName = 'Textarea' Textarea.displayName = 'Textarea'
export default Textarea export default Textarea

View File

@ -1,14 +1,14 @@
import type { ComponentProps, FC, ReactNode } from 'react' import type { ComponentProps, FC, ReactNode } from 'react'
import { forwardRef } from 'react'
import classNames from '@/utils/classnames' import classNames from '@/utils/classnames'
export type PreviewContainerProps = ComponentProps<'div'> & { export type PreviewContainerProps = ComponentProps<'div'> & {
header: ReactNode header: ReactNode
mainClassName?: string mainClassName?: string
ref?: React.Ref<HTMLDivElement>
} }
export const PreviewContainer: FC<PreviewContainerProps> = forwardRef((props, ref) => { export const PreviewContainer: FC<PreviewContainerProps> = (props) => {
const { children, className, header, mainClassName, ...rest } = props const { children, className, header, mainClassName, ref, ...rest } = props
return <div className={className}> return <div className={className}>
<div <div
{...rest} {...rest}
@ -25,5 +25,5 @@ export const PreviewContainer: FC<PreviewContainerProps> = forwardRef((props, re
</main> </main>
</div> </div>
</div> </div>
}) }
PreviewContainer.displayName = 'PreviewContainer' PreviewContainer.displayName = 'PreviewContainer'

View File

@ -740,84 +740,6 @@ Workflow applications offers non-session support and is ideal for translation, a
--- ---
<Heading
url='/files/:file_id/preview'
method='GET'
title='File Preview'
name='#file-preview'
/>
<Row>
<Col>
Preview or download uploaded files. This endpoint allows you to access files that have been previously uploaded via the File Upload API.
<i>Files can only be accessed if they belong to messages within the requesting application.</i>
### Path Parameters
- `file_id` (string) Required
The unique identifier of the file to preview, obtained from the File Upload API response.
### Query Parameters
- `as_attachment` (boolean) Optional
Whether to force download the file as an attachment. Default is `false` (preview in browser).
### Response
Returns the file content with appropriate headers for browser display or download.
- `Content-Type` Set based on file mime type
- `Content-Length` File size in bytes (if available)
- `Content-Disposition` Set to "attachment" if `as_attachment=true`
- `Cache-Control` Caching headers for performance
- `Accept-Ranges` Set to "bytes" for audio/video files
### Errors
- 400, `invalid_param`, abnormal parameter input
- 403, `file_access_denied`, file access denied or file does not belong to current application
- 404, `file_not_found`, file not found or has been deleted
- 500, internal server error
</Col>
<Col sticky>
### Request Example
<CodeGroup
title="Request"
tag="GET"
label="/files/:file_id/preview"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### Download as Attachment
<CodeGroup
title="Download Request"
tag="GET"
label="/files/:file_id/preview?as_attachment=true"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\
--header 'Authorization: Bearer {api_key}' \\
--output downloaded_file.png`}
/>
### Response Headers Example
<CodeGroup title="Response Headers">
```http {{ title: 'Headers - Image Preview' }}
Content-Type: image/png
Content-Length: 1024
Cache-Control: public, max-age=3600
```
</CodeGroup>
### Download Response Headers
<CodeGroup title="Download Response Headers">
```http {{ title: 'Headers - File Download' }}
Content-Type: image/png
Content-Length: 1024
Content-Disposition: attachment; filename*=UTF-8''example.png
Cache-Control: public, max-age=3600
```
</CodeGroup>
</Col>
</Row>
---
<Heading <Heading
url='/workflows/logs' url='/workflows/logs'
method='GET' method='GET'

View File

@ -736,84 +736,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
--- ---
<Heading
url='/files/:file_id/preview'
method='GET'
title='ファイルプレビュー'
name='#file-preview'
/>
<Row>
<Col>
アップロードされたファイルをプレビューまたはダウンロードします。このエンドポイントを使用すると、以前にファイルアップロード API でアップロードされたファイルにアクセスできます。
<i>ファイルは、リクエストしているアプリケーションのメッセージ範囲内にある場合のみアクセス可能です。</i>
### パスパラメータ
- `file_id` (string) 必須
プレビューするファイルの一意識別子。ファイルアップロード API レスポンスから取得します。
### クエリパラメータ
- `as_attachment` (boolean) オプション
ファイルを添付ファイルとして強制ダウンロードするかどうか。デフォルトは `false`(ブラウザでプレビュー)。
### レスポンス
ブラウザ表示またはダウンロード用の適切なヘッダー付きでファイル内容を返します。
- `Content-Type` ファイル MIME タイプに基づいて設定
- `Content-Length` ファイルサイズ(バイト、利用可能な場合)
- `Content-Disposition` `as_attachment=true` の場合は "attachment" に設定
- `Cache-Control` パフォーマンス向上のためのキャッシュヘッダー
- `Accept-Ranges` 音声/動画ファイルの場合は "bytes" に設定
### エラー
- 400, `invalid_param`, パラメータ入力異常
- 403, `file_access_denied`, ファイルアクセス拒否またはファイルが現在のアプリケーションに属していません
- 404, `file_not_found`, ファイルが見つからないか削除されています
- 500, サーバー内部エラー
</Col>
<Col sticky>
### リクエスト例
<CodeGroup
title="Request"
tag="GET"
label="/files/:file_id/preview"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### 添付ファイルとしてダウンロード
<CodeGroup
title="Download Request"
tag="GET"
label="/files/:file_id/preview?as_attachment=true"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\
--header 'Authorization: Bearer {api_key}' \\
--output downloaded_file.png`}
/>
### レスポンスヘッダー例
<CodeGroup title="Response Headers">
```http {{ title: 'ヘッダー - 画像プレビュー' }}
Content-Type: image/png
Content-Length: 1024
Cache-Control: public, max-age=3600
```
</CodeGroup>
### ダウンロードレスポンスヘッダー
<CodeGroup title="Download Response Headers">
```http {{ title: 'ヘッダー - ファイルダウンロード' }}
Content-Type: image/png
Content-Length: 1024
Content-Disposition: attachment; filename*=UTF-8''example.png
Cache-Control: public, max-age=3600
```
</CodeGroup>
</Col>
</Row>
---
<Heading <Heading
url='/workflows/logs' url='/workflows/logs'
method='GET' method='GET'

View File

@ -727,83 +727,6 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
</Row> </Row>
--- ---
<Heading
url='/files/:file_id/preview'
method='GET'
title='文件预览'
name='#file-preview'
/>
<Row>
<Col>
预览或下载已上传的文件。此端点允许您访问先前通过文件上传 API 上传的文件。
<i>文件只能在属于请求应用程序的消息范围内访问。</i>
### 路径参数
- `file_id` (string) 必需
要预览的文件的唯一标识符,从文件上传 API 响应中获得。
### 查询参数
- `as_attachment` (boolean) 可选
是否强制将文件作为附件下载。默认为 `false`(在浏览器中预览)。
### 响应
返回带有适当浏览器显示或下载标头的文件内容。
- `Content-Type` 根据文件 MIME 类型设置
- `Content-Length` 文件大小(以字节为单位,如果可用)
- `Content-Disposition` 如果 `as_attachment=true` 则设置为 "attachment"
- `Cache-Control` 用于性能的缓存标头
- `Accept-Ranges` 对于音频/视频文件设置为 "bytes"
### 错误
- 400, `invalid_param`, 参数输入异常
- 403, `file_access_denied`, 文件访问被拒绝或文件不属于当前应用程序
- 404, `file_not_found`, 文件未找到或已被删除
- 500, 服务内部错误
</Col>
<Col sticky>
### 请求示例
<CodeGroup
title="Request"
tag="GET"
label="/files/:file_id/preview"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview' \\
--header 'Authorization: Bearer {api_key}'`}
/>
### 作为附件下载
<CodeGroup
title="Request"
tag="GET"
label="/files/:file_id/preview?as_attachment=true"
targetCode={`curl -X GET '${props.appDetail.api_base_url}/files/72fa9618-8f89-4a37-9b33-7e1178a24a67/preview?as_attachment=true' \\
--header 'Authorization: Bearer {api_key}' \\
--output downloaded_file.png`}
/>
### 响应标头示例
<CodeGroup title="Response Headers">
```http {{ title: 'Headers - 图片预览' }}
Content-Type: image/png
Content-Length: 1024
Cache-Control: public, max-age=3600
```
</CodeGroup>
### 文件下载响应标头
<CodeGroup title="Download Response Headers">
```http {{ title: 'Headers - 文件下载' }}
Content-Type: image/png
Content-Length: 1024
Content-Disposition: attachment; filename*=UTF-8''example.png
Cache-Control: public, max-age=3600
```
</CodeGroup>
</Col>
</Row>
---
<Heading <Heading
url='/workflows/logs' url='/workflows/logs'
method='GET' method='GET'

View File

@ -1,5 +1,4 @@
'use client' 'use client'
import type { ForwardRefRenderFunction } from 'react'
import { useImperativeHandle } from 'react' import { useImperativeHandle } from 'react'
import React, { useCallback, useEffect, useMemo, useState } from 'react' import React, { useCallback, useEffect, useMemo, useState } from 'react'
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types' import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types'
@ -21,6 +20,7 @@ type Props = {
onDeSelectAll: () => void onDeSelectAll: () => void
onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void
isFromMarketPlace?: boolean isFromMarketPlace?: boolean
ref?: React.Ref<ExposeRefs>
} }
export type ExposeRefs = { export type ExposeRefs = {
@ -28,7 +28,7 @@ export type ExposeRefs = {
deSelectAllPlugins: () => void deSelectAllPlugins: () => void
} }
const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({ const InstallByDSLList = ({
allPlugins, allPlugins,
selectedPlugins, selectedPlugins,
onSelect, onSelect,
@ -36,7 +36,8 @@ const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({
onDeSelectAll, onDeSelectAll,
onLoadedAllPlugin, onLoadedAllPlugin,
isFromMarketPlace, isFromMarketPlace,
}, ref) => { ref,
}: Props) => {
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
// DSL has id, to get plugin info to show more info // DSL has id, to get plugin info to show more info
const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => { const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => {
@ -268,4 +269,4 @@ const InstallByDSLList: ForwardRefRenderFunction<ExposeRefs, Props> = ({
</> </>
) )
} }
export default React.forwardRef(InstallByDSLList) export default InstallByDSLList

View File

@ -82,9 +82,7 @@ const PluginTypeSwitch = ({
}, [showSearchParams, handleActivePluginTypeChange]) }, [showSearchParams, handleActivePluginTypeChange])
useEffect(() => { useEffect(() => {
window.addEventListener('popstate', () => { window.addEventListener('popstate', handlePopState)
handlePopState()
})
return () => { return () => {
window.removeEventListener('popstate', handlePopState) window.removeEventListener('popstate', handlePopState)
} }

View File

@ -1,5 +1,5 @@
'use client' 'use client'
import React, { forwardRef, useEffect, useImperativeHandle, useMemo, useRef } from 'react' import React, { useEffect, useImperativeHandle, useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll' import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll'
import Item from './item' import Item from './item'
@ -17,18 +17,20 @@ export type ListProps = {
tags: string[] tags: string[]
toolContentClassName?: string toolContentClassName?: string
disableMaxWidth?: boolean disableMaxWidth?: boolean
ref?: React.Ref<ListRef>
} }
export type ListRef = { handleScroll: () => void } export type ListRef = { handleScroll: () => void }
const List = forwardRef<ListRef, ListProps>(({ const List = ({
wrapElemRef, wrapElemRef,
searchText, searchText,
tags, tags,
list, list,
toolContentClassName, toolContentClassName,
disableMaxWidth = false, disableMaxWidth = false,
}, ref) => { ref,
}: ListProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const hasFilter = !searchText const hasFilter = !searchText
const hasRes = list.length > 0 const hasRes = list.length > 0
@ -125,7 +127,7 @@ const List = forwardRef<ListRef, ListProps>(({
</div> </div>
</> </>
) )
}) }
List.displayName = 'List' List.displayName = 'List'

View File

@ -1003,6 +1003,10 @@ const translation = {
noLastRunFound: 'Kein vorheriger Lauf gefunden', noLastRunFound: 'Kein vorheriger Lauf gefunden',
lastOutput: 'Letzte Ausgabe', lastOutput: 'Letzte Ausgabe',
}, },
sidebar: {
exportWarning: 'Aktuelle gespeicherte Version exportieren',
exportWarningDesc: 'Dies wird die derzeit gespeicherte Version Ihres Workflows exportieren. Wenn Sie ungespeicherte Änderungen im Editor haben, speichern Sie diese bitte zuerst, indem Sie die Exportoption im Workflow-Canvas verwenden.',
},
} }
export default translation export default translation

View File

@ -139,6 +139,10 @@ const translation = {
export: 'Export DSL with secret values ', export: 'Export DSL with secret values ',
}, },
}, },
sidebar: {
exportWarning: 'Export Current Saved Version',
exportWarningDesc: 'This will export the current saved version of your workflow. If you have unsaved changes in the editor, please save them first by using the export option in the workflow canvas.',
},
chatVariable: { chatVariable: {
panelTitle: 'Conversation Variables', panelTitle: 'Conversation Variables',
panelDescription: 'Conversation Variables are used to store interactive information that LLM needs to remember, including conversation history, uploaded files, user preferences. They are read-write. ', panelDescription: 'Conversation Variables are used to store interactive information that LLM needs to remember, including conversation history, uploaded files, user preferences. They are read-write. ',

View File

@ -1003,6 +1003,10 @@ const translation = {
noMatchingInputsFound: 'No se encontraron entradas coincidentes de la última ejecución.', noMatchingInputsFound: 'No se encontraron entradas coincidentes de la última ejecución.',
lastOutput: 'Última salida', lastOutput: 'Última salida',
}, },
sidebar: {
exportWarning: 'Exportar la versión guardada actual',
exportWarningDesc: 'Esto exportará la versión guardada actual de tu flujo de trabajo. Si tienes cambios no guardados en el editor, guárdalos primero utilizando la opción de exportar en el lienzo del flujo de trabajo.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRunError: 'نتوانستم ورودی‌های آخرین اجرای را کپی کنم', copyLastRunError: 'نتوانستم ورودی‌های آخرین اجرای را کپی کنم',
lastOutput: 'آخرین خروجی', lastOutput: 'آخرین خروجی',
}, },
sidebar: {
exportWarning: 'صادرات نسخه ذخیره شده فعلی',
exportWarningDesc: 'این نسخه فعلی ذخیره شده از کار خود را صادر خواهد کرد. اگر تغییرات غیرذخیره شده‌ای در ویرایشگر دارید، لطفاً ابتدا از گزینه صادرات در بوم کار برای ذخیره آنها استفاده کنید.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRunError: 'Échec de la copie des entrées de la dernière exécution', copyLastRunError: 'Échec de la copie des entrées de la dernière exécution',
lastOutput: 'Dernière sortie', lastOutput: 'Dernière sortie',
}, },
sidebar: {
exportWarning: 'Exporter la version enregistrée actuelle',
exportWarningDesc: 'Cela exportera la version actuelle enregistrée de votre flux de travail. Si vous avez des modifications non enregistrées dans l\'éditeur, veuillez d\'abord les enregistrer en utilisant l\'option d\'exportation dans le canevas du flux de travail.',
},
} }
export default translation export default translation

View File

@ -1023,6 +1023,10 @@ const translation = {
copyLastRunError: 'अंतिम रन इनपुट को कॉपी करने में विफल', copyLastRunError: 'अंतिम रन इनपुट को कॉपी करने में विफल',
lastOutput: 'अंतिम आउटपुट', lastOutput: 'अंतिम आउटपुट',
}, },
sidebar: {
exportWarning: 'वर्तमान सहेजी गई संस्करण निर्यात करें',
exportWarningDesc: 'यह आपके कार्यप्रवाह का वर्तमान सहेजा हुआ संस्करण निर्यात करेगा। यदि आपके संपादक में कोई असहेजा किए गए परिवर्तन हैं, तो कृपया पहले उन्हें सहेजें, कार्यप्रवाह कैनवास में निर्यात विकल्प का उपयोग करके।',
},
} }
export default translation export default translation

View File

@ -967,6 +967,10 @@ const translation = {
lastOutput: 'Keluaran Terakhir', lastOutput: 'Keluaran Terakhir',
noLastRunFound: 'Tidak ada eksekusi sebelumnya ditemukan', noLastRunFound: 'Tidak ada eksekusi sebelumnya ditemukan',
}, },
sidebar: {
exportWarning: 'Ekspor Versi Tersimpan Saat Ini',
exportWarningDesc: 'Ini akan mengekspor versi terkini dari alur kerja Anda yang telah disimpan. Jika Anda memiliki perubahan yang belum disimpan di editor, harap simpan terlebih dahulu dengan menggunakan opsi ekspor di kanvas alur kerja.',
},
} }
export default translation export default translation

View File

@ -1029,6 +1029,10 @@ const translation = {
noLastRunFound: 'Nessuna esecuzione precedente trovata', noLastRunFound: 'Nessuna esecuzione precedente trovata',
lastOutput: 'Ultimo output', lastOutput: 'Ultimo output',
}, },
sidebar: {
exportWarning: 'Esporta la versione salvata corrente',
exportWarningDesc: 'Questo exporterà l\'attuale versione salvata del tuo flusso di lavoro. Se hai modifiche non salvate nell\'editor, ti preghiamo di salvarle prima utilizzando l\'opzione di esportazione nel canvas del flusso di lavoro.',
},
} }
export default translation export default translation

View File

@ -139,6 +139,10 @@ const translation = {
export: 'シークレット値付きでエクスポート', export: 'シークレット値付きでエクスポート',
}, },
}, },
sidebar: {
exportWarning: '現在保存されているバージョンをエクスポート',
exportWarningDesc: 'これは現在保存されているワークフローのバージョンをエクスポートします。エディターで未保存の変更がある場合は、まずワークフローキャンバスのエクスポートオプションを使用して保存してください。',
},
chatVariable: { chatVariable: {
panelTitle: '会話変数', panelTitle: '会話変数',
panelDescription: '対話情報を保存・管理(会話履歴/ファイル/ユーザー設定など)。書き換えができます。', panelDescription: '対話情報を保存・管理(会話履歴/ファイル/ユーザー設定など)。書き換えができます。',

View File

@ -1054,6 +1054,10 @@ const translation = {
copyLastRunError: '마지막 실행 입력을 복사하는 데 실패했습니다.', copyLastRunError: '마지막 실행 입력을 복사하는 데 실패했습니다.',
lastOutput: '마지막 출력', lastOutput: '마지막 출력',
}, },
sidebar: {
exportWarning: '현재 저장된 버전 내보내기',
exportWarningDesc: '이 작업은 현재 저장된 워크플로우 버전을 내보냅니다. 편집기에서 저장되지 않은 변경 사항이 있는 경우, 먼저 워크플로우 캔버스의 내보내기 옵션을 사용하여 저장해 주세요.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRunError: 'Nie udało się skopiować danych wejściowych z ostatniego uruchomienia', copyLastRunError: 'Nie udało się skopiować danych wejściowych z ostatniego uruchomienia',
lastOutput: 'Ostatni wynik', lastOutput: 'Ostatni wynik',
}, },
sidebar: {
exportWarning: 'Eksportuj obecną zapisaną wersję',
exportWarningDesc: 'To wyeksportuje aktualnie zapisaną wersję twojego przepływu pracy. Jeśli masz niesave\'owane zmiany w edytorze, najpierw je zapisz, korzystając z opcji eksportu w kanwie przepływu pracy.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRun: 'Copiar Última Execução', copyLastRun: 'Copiar Última Execução',
lastOutput: 'Última Saída', lastOutput: 'Última Saída',
}, },
sidebar: {
exportWarning: 'Exportar a versão salva atual',
exportWarningDesc: 'Isto irá exportar a versão atual salva do seu fluxo de trabalho. Se você tiver alterações não salvas no editor, por favor, salve-as primeiro utilizando a opção de exportação na tela do fluxo de trabalho.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRunError: 'Nu s-au putut copia ultimele intrări de rulare', copyLastRunError: 'Nu s-au putut copia ultimele intrări de rulare',
lastOutput: 'Ultimul rezultat', lastOutput: 'Ultimul rezultat',
}, },
sidebar: {
exportWarning: 'Exportați versiunea salvată curentă',
exportWarningDesc: 'Aceasta va exporta versiunea curent salvată a fluxului dumneavoastră de lucru. Dacă aveți modificări nesalvate în editor, vă rugăm să le salvați mai întâi utilizând opțiunea de export din canvasul fluxului de lucru.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
noMatchingInputsFound: 'Не найдено соответствующих входных данных из последнего запуска.', noMatchingInputsFound: 'Не найдено соответствующих входных данных из последнего запуска.',
lastOutput: 'Последний вывод', lastOutput: 'Последний вывод',
}, },
sidebar: {
exportWarning: 'Экспортировать текущую сохранённую версию',
exportWarningDesc: 'Это экспортирует текущую сохранённую версию вашего рабочего процесса. Если у вас есть несохранённые изменения в редакторе, сначала сохраните их с помощью опции экспорта на полотне рабочего процесса.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
noMatchingInputsFound: 'Ni podatkov, ki bi ustrezali prejšnjemu zagonu', noMatchingInputsFound: 'Ni podatkov, ki bi ustrezali prejšnjemu zagonu',
lastOutput: 'Nazadnje izhod', lastOutput: 'Nazadnje izhod',
}, },
sidebar: {
exportWarning: 'Izvozi trenutna shranjena različica',
exportWarningDesc: 'To bo izvozilo trenutno shranjeno različico vašega delovnega toka. Če imate neshranjene spremembe v urejevalniku, jih najprej shranite z uporabo možnosti izvoza na platnu delovnega toka.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
noMatchingInputsFound: 'ไม่พบข้อมูลที่ตรงกันจากการรันครั้งล่าสุด', noMatchingInputsFound: 'ไม่พบข้อมูลที่ตรงกันจากการรันครั้งล่าสุด',
lastOutput: 'ผลลัพธ์สุดท้าย', lastOutput: 'ผลลัพธ์สุดท้าย',
}, },
sidebar: {
exportWarning: 'ส่งออกเวอร์ชันที่บันทึกปัจจุบัน',
exportWarningDesc: 'นี่จะส่งออกเวอร์ชันที่บันทึกไว้ปัจจุบันของเวิร์กโฟลว์ของคุณ หากคุณมีการเปลี่ยนแปลงที่ยังไม่ได้บันทึกในแก้ไข กรุณาบันทึกมันก่อนโดยใช้ตัวเลือกส่งออกในผืนผ้าใบเวิร์กโฟลว์',
},
} }
export default translation export default translation

View File

@ -1004,6 +1004,10 @@ const translation = {
copyLastRunError: 'Son çalışma girdilerini kopyalamak başarısız oldu.', copyLastRunError: 'Son çalışma girdilerini kopyalamak başarısız oldu.',
lastOutput: 'Son Çıktı', lastOutput: 'Son Çıktı',
}, },
sidebar: {
exportWarning: 'Mevcut Kaydedilmiş Versiyonu Dışa Aktar',
exportWarningDesc: 'Bu, çalışma akışınızın mevcut kaydedilmiş sürümünü dışa aktaracaktır. Editörde kaydedilmemiş değişiklikleriniz varsa, lütfen önce bunları çalışma akışı alanındaki dışa aktarma seçeneğini kullanarak kaydedin.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
noMatchingInputsFound: 'Не знайдено відповідних вхідних даних з останнього запуску', noMatchingInputsFound: 'Не знайдено відповідних вхідних даних з останнього запуску',
lastOutput: 'Останній вихід', lastOutput: 'Останній вихід',
}, },
sidebar: {
exportWarning: 'Експортувати поточну збережену версію',
exportWarningDesc: 'Це експортує поточну збережену версію вашого робочого процесу. Якщо у вас є незбережені зміни в редакторі, будь ласка, спочатку збережіть їх, використовуючи опцію експорту на полотні робочого процесу.',
},
} }
export default translation export default translation

View File

@ -1003,6 +1003,10 @@ const translation = {
copyLastRunError: 'Không thể sao chép đầu vào của lần chạy trước', copyLastRunError: 'Không thể sao chép đầu vào của lần chạy trước',
lastOutput: 'Đầu ra cuối cùng', lastOutput: 'Đầu ra cuối cùng',
}, },
sidebar: {
exportWarning: 'Xuất Phiên Bản Đã Lưu Hiện Tại',
exportWarningDesc: 'Điều này sẽ xuất phiên bản hiện tại đã được lưu của quy trình làm việc của bạn. Nếu bạn có những thay đổi chưa được lưu trong trình soạn thảo, vui lòng lưu chúng trước bằng cách sử dụng tùy chọn xuất trong bản vẽ quy trình.',
},
} }
export default translation export default translation

View File

@ -139,6 +139,10 @@ const translation = {
export: '导出包含 Secret 值的 DSL', export: '导出包含 Secret 值的 DSL',
}, },
}, },
sidebar: {
exportWarning: '导出当前已保存版本',
exportWarningDesc: '这将导出您工作流的当前已保存版本。如果您在编辑器中有未保存的更改,请先使用工作流画布中的导出选项保存它们。',
},
chatVariable: { chatVariable: {
panelTitle: '会话变量', panelTitle: '会话变量',
panelDescription: '会话变量用于存储 LLM 需要的上下文信息,如用户偏好、对话历史等。它是可读写的。', panelDescription: '会话变量用于存储 LLM 需要的上下文信息,如用户偏好、对话历史等。它是可读写的。',

View File

@ -1003,6 +1003,10 @@ const translation = {
noLastRunFound: '沒有找到之前的運行', noLastRunFound: '沒有找到之前的運行',
lastOutput: '最後的輸出', lastOutput: '最後的輸出',
}, },
sidebar: {
exportWarning: '導出當前保存的版本',
exportWarningDesc: '這將導出當前保存的工作流程版本。如果您在編輯器中有未保存的更改,請先通過使用工作流程畫布中的導出選項來保存它們。',
},
} }
export default translation export default translation

View File

@ -1 +0,0 @@
(()=>{"use strict";self.fallback=async e=>"document"===e.destination?caches.match("/_offline.html",{ignoreSearch:!0}):Response.error()})();

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,84 @@
import { DataType } from '@/app/components/datasets/metadata/types'
import { act, renderHook } from '@testing-library/react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { useBatchUpdateDocMetadata } from '@/service/knowledge/use-metadata'
import { useDocumentListKey } from './use-document'
// Mock the post function to avoid real network requests
jest.mock('@/service/base', () => ({
post: jest.fn().mockResolvedValue({ success: true }),
}))
const NAME_SPACE = 'dataset-metadata'
describe('useBatchUpdateDocMetadata', () => {
let queryClient: QueryClient
beforeEach(() => {
// Create a fresh QueryClient before each test
queryClient = new QueryClient()
})
// Wrapper for React Query context
const wrapper = ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
it('should correctly invalidate dataset and document caches', async () => {
const { result } = renderHook(() => useBatchUpdateDocMetadata(), { wrapper })
// Spy on queryClient.invalidateQueries
const invalidateSpy = jest.spyOn(queryClient, 'invalidateQueries')
// Correct payload type: each document has its own metadata_list array
const payload = {
dataset_id: 'dataset-1',
metadata_list: [
{
document_id: 'doc-1',
metadata_list: [
{ key: 'title-1', id: '01', name: 'name-1', type: DataType.string, value: 'new title 01' },
],
},
{
document_id: 'doc-2',
metadata_list: [
{ key: 'title-2', id: '02', name: 'name-1', type: DataType.string, value: 'new title 02' },
],
},
],
}
// Execute the mutation
await act(async () => {
await result.current.mutateAsync(payload)
})
// Expect invalidateQueries to have been called exactly 5 times
expect(invalidateSpy).toHaveBeenCalledTimes(5)
// Dataset cache invalidation
expect(invalidateSpy).toHaveBeenNthCalledWith(1, {
queryKey: [NAME_SPACE, 'dataset', 'dataset-1'],
})
// Document list cache invalidation
expect(invalidateSpy).toHaveBeenNthCalledWith(2, {
queryKey: [NAME_SPACE, 'document', 'dataset-1'],
})
// useDocumentListKey cache invalidation
expect(invalidateSpy).toHaveBeenNthCalledWith(3, {
queryKey: [...useDocumentListKey, 'dataset-1'],
})
// Single document cache invalidation
expect(invalidateSpy.mock.calls.slice(3)).toEqual(
expect.arrayContaining([
[{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-1'] }],
[{ queryKey: [NAME_SPACE, 'document', 'dataset-1', 'doc-2'] }],
]),
)
})
})

View File

@ -119,7 +119,7 @@ export const useBatchUpdateDocMetadata = () => {
}) })
// meta data in document list // meta data in document list
await queryClient.invalidateQueries({ await queryClient.invalidateQueries({
queryKey: [NAME_SPACE, 'dataset', payload.dataset_id], queryKey: [NAME_SPACE, 'document', payload.dataset_id],
}) })
await queryClient.invalidateQueries({ await queryClient.invalidateQueries({
queryKey: [...useDocumentListKey, payload.dataset_id], queryKey: [...useDocumentListKey, payload.dataset_id],