Merge branch 'main' into feat/r2

# Conflicts:
#	docker/docker-compose.middleware.yaml
jyong 2025-06-11 17:16:27 +08:00
commit 1d8b390584
304 changed files with 7257 additions and 4518 deletions

View File

@ -8,7 +8,7 @@ inputs:
uv-version:
description: UV version to set up
required: true
default: '0.6.14'
default: '~=0.7.11'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true

7
.gitignore vendored
View File

@ -192,12 +192,12 @@ sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
.idea/
.vscode
# pnpm
/.pnpm-store
@ -207,3 +207,6 @@ plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/

14
.vscode/README.md vendored Normal file
View File

@ -0,0 +1,14 @@
# Debugging with VS Code
This `launch.json.template` file provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, you should copy the contents of this file into a new file named `launch.json` in the same `.vscode` directory.
## How to Use
1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory.
2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file.
3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D).
4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button.
## Tips
- If you need to debug with the Edge browser instead of Chrome, modify the `serverReadyAction` configuration in the "Next.js: debug full stack" section: change `"debugWithChrome"` to `"debugWithEdge"`.

68
.vscode/launch.json.template vendored Normal file
View File

@ -0,0 +1,68 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Flask API",
"type": "debugpy",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "app.py",
"FLASK_ENV": "development",
"GEVENT_SUPPORT": "True"
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--no-debugger",
"--no-reload"
],
"jinja": true,
"justMyCode": true,
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/api/.venv/bin/python"
},
{
"name": "Python: Celery Worker (Solo)",
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"GEVENT_SUPPORT": "True"
},
"args": [
"-A",
"app.celery",
"worker",
"-P",
"solo",
"-c",
"1",
"-Q",
"dataset,generation,mail,ops_trace",
"--loglevel",
"INFO"
],
"justMyCode": false,
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/api/.venv/bin/python"
},
{
"name": "Next.js: debug full stack",
"type": "node",
"request": "launch",
"program": "${workspaceFolder}/web/node_modules/next/dist/bin/next",
"runtimeArgs": ["--inspect"],
"skipFiles": ["<node_internals>/**"],
"serverReadyAction": {
"action": "debugWithChrome",
"killOnServerStop": true,
"pattern": "- Local:.+(https?://.+)",
"uriFormat": "%s",
"webRoot": "${workspaceFolder}/web"
},
"cwd": "${workspaceFolder}/web"
}
]
}

View File

@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
# Dataset queue monitor configuration
QUEUE_MONITOR_THRESHOLD=200
# You can configure multiple recipients, separated by commas, e.g. test1@dify.ai,test2@dify.ai
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30

View File

@ -43,6 +43,7 @@ select = [
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage
]
ignore = [

View File

@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install uv
ENV UV_VERSION=0.6.14
ENV UV_VERSION=0.7.11
RUN pip install --no-cache-dir uv==${UV_VERSION}

View File

@ -2,7 +2,7 @@ import os
from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
)
class DatasetQueueMonitorConfig(BaseSettings):
"""
Configuration settings for Dataset Queue Monitor
"""
QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
description="Threshold for dataset queue monitor",
default=200,
)
QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
description="Emails for dataset queue monitor alert, separated by commas",
default=None,
)
QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
description="Interval for dataset queue monitor in minutes",
default=30,
)
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@ -303,5 +322,6 @@ class MiddlewareConfig(
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
DatasetQueueMonitorConfig,
):
pass
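
Since MiddlewareConfig now mixes in DatasetQueueMonitorConfig, the new settings are readable from the shared dify_config object imported elsewhere in this diff. A minimal consumption sketch (the helper and alerting logic below are hypothetical, not part of this commit):

from configs import dify_config

def should_alert(queue_length: int) -> bool:
    # Hypothetical helper: compare the current queue length against
    # QUEUE_MONITOR_THRESHOLD (defaults to 200).
    threshold = dify_config.QUEUE_MONITOR_THRESHOLD or 0
    return threshold > 0 and queue_length > threshold

# QUEUE_MONITOR_ALERT_EMAILS is a comma-separated string, e.g. "a@dify.ai,b@dify.ai"
emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS or ""
recipients = [addr.strip() for addr in emails.split(",") if addr.strip()]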

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.4.1",
default="1.4.2",
)
COMMIT_SHA: str = Field(

View File

@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith(".csv"):
if not file.filename or not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -119,9 +119,6 @@ class ForgotPasswordResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()

View File

@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename.endswith(".csv"):
if not file.filename or not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:

View File

@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource):
if FeatureService.get_system_features().webapp_auth.enabled:
user_id = current_user.id
res = []
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
for installed_app in installed_app_list:
webapp_setting = webapp_settings.get(installed_app["app"].id)
if not webapp_setting:
continue
if webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,

View File

@ -44,6 +44,17 @@ def only_edition_cloud(view):
return decorated
def only_edition_enterprise(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
return view(*args, **kwargs)
return decorated
def only_edition_self_hosted(view):
@wraps(view)
def decorated(*args, **kwargs):

View File

@ -29,7 +29,7 @@ from core.plugin.entities.request import (
RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from libs.helper import length_prefixed_response
from models.account import Account, Tenant
from models.model import EndUser
@ -44,7 +44,7 @@ class PluginInvokeLLMApi(Resource):
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
return length_prefixed_response(0xF, generator())
class PluginInvokeTextEmbeddingApi(Resource):
@ -101,7 +101,7 @@ class PluginInvokeTTSApi(Resource):
)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return compact_generate_response(generator())
return length_prefixed_response(0xF, generator())
class PluginInvokeSpeech2TextApi(Resource):
@ -162,7 +162,7 @@ class PluginInvokeToolApi(Resource):
),
)
return compact_generate_response(generator())
return length_prefixed_response(0xF, generator())
class PluginInvokeParameterExtractorNodeApi(Resource):
@ -228,7 +228,7 @@ class PluginInvokeAppApi(Resource):
files=payload.files,
)
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
class PluginInvokeEncryptApi(Resource):

View File

@ -32,6 +32,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
)
session.add(user_model)
session.commit()
session.refresh(user_model)
else:
user_model = AccountService.load_user(user_id)
if not user_model:

View File

@ -369,6 +369,7 @@ class DatasetTagsApi(DatasetApiResource):
)
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id"))
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))

View File

@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"):
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique
# save file info
file = request.files["file"]
@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource):
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
raise ValueError("process_rule is required.")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
dataset_process_rule=dataset_process_rule,
created_from="api",
)
except ProviderTokenNotInitError as ex:

View File

@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
from . import (
app,
audio,
completion,
conversation,
feature,
forgot_password,
login,
message,
passport,
saved_message,
site,
workflow,
)

View File

@ -10,6 +10,8 @@ from libs.passport import PassportService
from models.model import App, AppMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
class AppParameterApi(WebApiResource):
@ -46,10 +48,22 @@ class AppMeta(WebApiResource):
class AppAccessMode(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args")
parser.add_argument("appId", type=str, required=False, location="args")
parser.add_argument("appCode", type=str, required=False, location="args")
args = parser.parse_args()
app_id = args["appId"]
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"accessMode": "public"}
app_id = args.get("appId")
if args.get("appCode"):
app_code = args["appCode"]
app_id = AppService.get_app_id_by_code(app_code)
if not app_id:
raise ValueError("appId or appCode must be provided")
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
return {"accessMode": res.access_mode}
@ -75,6 +89,10 @@ class AppWebAuthPermission(Resource):
except Exception as e:
pass
features = FeatureService.get_system_features()
if not features.webapp_auth.enabled:
return {"result": True}
parser = reqparse.RequestParser()
parser.add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args()
@ -82,7 +100,9 @@ class AppWebAuthPermission(Resource):
app_id = args["appId"]
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
res = True
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
return {"result": res}

View File

@ -0,0 +1,147 @@
import base64
import secrets
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.auth.error import (
EmailCodeError,
EmailPasswordResetLimitError,
InvalidEmailError,
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import api
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
class ForgotPasswordSendEmailApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
if account is None:
raise AccountNotFound()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource):
@only_edition_enterprise
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args()
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account:
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AccountNotFound()
return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")

View File

@ -1,12 +1,11 @@
from flask import request
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from werkzeug.exceptions import BadRequest
import services
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import setup_required
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
@ -16,6 +15,8 @@ from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource):
"""Resource for web app email/password login."""
@setup_required
@only_edition_enterprise
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@ -23,10 +24,6 @@ class LoginApi(Resource):
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
@ -36,12 +33,8 @@ class LoginApi(Resource):
except services.errors.account.AccountNotFoundError:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
return {"result": "success", "token": token}
token = WebAppAuthService.login(account=account)
return {"result": "success", "data": {"access_token": token}}
# class LogoutApi(Resource):
@ -56,6 +49,7 @@ class LoginApi(Resource):
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@only_edition_enterprise
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@ -78,6 +72,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@only_edition_enterprise
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@ -86,9 +81,6 @@ class EmailCodeLoginApi(Resource):
args = parser.parse_args()
user_email = args["email"]
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise BadRequest("X-App-Code header is missing.")
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
@ -105,16 +97,12 @@ class EmailCodeLoginApi(Resource):
if not account:
raise AccountNotFound()
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "token": token}
return {"result": "success", "data": {"access_token": token}}
# api.add_resource(LoginApi, "/login")
api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -1,9 +1,11 @@
import uuid
from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web import api
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
@ -11,6 +13,7 @@ from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
class PassportResource(Resource):
@ -20,10 +23,19 @@ class PassportResource(Resource):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
user_id = request.args.get("user_id")
web_app_access_token = request.args.get("web_app_access_token")
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
# exchange token for an enterprise logged-in web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
if enterprise_user_decoded:
# a web user has already logged in; exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not app_settings or not app_settings.access_mode == "public":
@ -84,6 +96,128 @@ class PassportResource(Resource):
api.add_resource(PassportResource, "/passport")
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise web-app user session from the provided login token.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
"""
Exchange a token for an existing web user session.
"""
user_id = enterprise_user_decoded.get("user_id")
end_user_id = enterprise_user_decoded.get("end_user_id")
session_id = enterprise_user_decoded.get("session_id")
user_auth_type = enterprise_user_decoded.get("auth_type")
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site:
raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)
if app_auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
raise NotFound("Missing session_id for existing web user.")
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=session_id,
)
db.session.add(end_user)
db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": user_id,
"end_user_id": end_user.id,
"auth_type": user_auth_type,
"granted_at": int(datetime.now(UTC).timestamp()),
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return {
"access_token": token,
}
def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
)
if not end_user:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type="browser",
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()
payload = {
"iss": site.app_id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"end_user_id": end_user.id,
}
tk = PassportService().issue(payload)
return {
"access_token": tk,
}
def generate_session_id():
"""
Generate a unique session ID.

View File

@ -1,3 +1,4 @@
from datetime import UTC, datetime
from functools import wraps
from flask import request
@ -8,8 +9,9 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
def validate_jwt_token(view=None):
@ -45,7 +47,8 @@ def decode_jwt_token():
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model:
raise NotFound()
@ -53,23 +56,30 @@ def decode_jwt_token():
raise BadRequest("Site URL is no longer valid.")
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound()
# for enterprise webapp auth
app_web_auth_enabled = False
webapp_settings = None
if system_features.webapp_auth.enabled:
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
)
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if not webapp_settings:
raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
)
return app_model, end_user
except Unauthorized as e:
if system_features.webapp_auth.enabled:
if not app_code:
raise Unauthorized("Please re-login to access the web app.")
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
)
@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
def _validate_user_accessibility(
decoded,
app_code,
app_web_auth_enabled: bool,
system_webapp_auth_enabled: bool,
webapp_settings: WebAppSettings | None,
):
if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app
user_id = decoded.get("user_id")
if not user_id:
raise WebAppAuthRequiredError()
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
if not webapp_settings:
raise WebAppAuthRequiredError("Web app settings not found.")
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
auth_type = decoded.get("auth_type")
granted_at = decoded.get("granted_at")
if not auth_type:
raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
if not granted_at:
raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
# check if sso has been updated
if auth_type == "external":
last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
elif auth_type == "internal":
last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
class WebApiResource(Resource):

View File

@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
def raise_for_status(self) -> None:
"""
Check model status and raise ValueError if not active.
:raises ValueError: When model status is not active, with a descriptive message
"""
if self.status == ModelStatus.ACTIVE:
return
error_messages = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}
if self.status in error_messages:
raise ValueError(error_messages[self.status])
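
A short usage sketch for the new raise_for_status helper (the caller below is illustrative, not part of this commit):

def ensure_model_usable(entity: "ProviderModelWithStatusEntity") -> None:
    # Hypothetical caller: convert a non-active status into a single
    # descriptive error instead of checking ModelStatus flags by hand.
    try:
        entity.raise_for_status()
    except ValueError as exc:  # e.g. "Model quota has been exceeded"
        raise RuntimeError(f"model not usable: {exc}") from exc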
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""

View File

@ -41,45 +41,53 @@ class Extensible:
extensions = []
position_map: dict[str, int] = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path)
# Get the package name from the module path
package_name = ".".join(cls.__module__.split(".")[:-1])
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith("__"):
continue
try:
# Get package directory path
package_spec = importlib.util.find_spec(package_name)
if not package_spec or not package_spec.origin:
raise ImportError(f"Could not find package {package_name}")
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
package_dir = os.path.dirname(package_spec.origin)
# Traverse subdirectories
for subdir_name in os.listdir(package_dir):
if subdir_name.startswith("__"):
continue
subdir_path = os.path.join(package_dir, subdir_name)
if not os.path.isdir(subdir_path):
continue
extension_name = subdir_name
file_names = os.listdir(subdir_path)
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
# Check for extension module file
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Check for builtin flag and position
builtin = False
# default position is 0; it cannot be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
# Import the extension module
module_name = f"{package_name}.{extension_name}.{extension_name}"
spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# Find extension class
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@ -87,21 +95,21 @@ class Extensible:
break
if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue
# Load schema if not builtin
json_data: dict[str, Any] = {}
if not builtin:
if "schema.json" not in file_names:
json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, "schema.json")
json_data = {}
if os.path.exists(json_path):
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
# Create extension
extensions.append(
ModuleExtension(
extension_class=extension_class,
@ -113,6 +121,11 @@ class Extensible:
)
)
except Exception as e:
logging.exception("Error scanning extensions")
raise
# Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)
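
For reference, the directory layout the rewritten scan expects, shown with a hypothetical extension name:

# Layout implied by the scanning logic above (names are illustrative):
#
#   <package dir of the Extensible subclass>/
#       my_extension/              # one subdirectory per extension
#           my_extension.py        # must define a subclass of the base class
#           schema.json            # required unless the extension is builtin
#           __builtin__            # optional flag file; holds an integer position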

View File

@ -1,5 +1,5 @@
import logging
import random
import secrets
from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
if len(text_chunks) == 0:
return True
text_chunk = random.choice(text_chunks)
text_chunk = secrets.choice(text_chunks)
try:
model_provider_factory = ModelProviderFactory(tenant_id)

View File

@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
class ParameterRule(BaseModel):
"""

View File

@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig):
entity: str | None = None
project: str
endpoint: str = "https://trace.wandb.ai"
host: str | None = None
@field_validator("endpoint")
@classmethod
@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig):
return v
@field_validator("host")
@classmethod
def validate_host(cls, v, info: ValidationInfo):
if v is not None and v != "":
if not v.startswith(("https://", "http://")):
raise ValueError("host must start with https:// or http://")
return v
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
return {
"config_class": WeaveConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"],
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}

View File

@ -40,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance):
self.weave_api_key = weave_config.api_key
self.project_name = weave_config.project
self.entity = weave_config.entity
self.host = weave_config.host
# Login with API key first, including host if provided
if self.host:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
# Login with API key first
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
logger.error("Failed to login to Weights & Biases with the provided API key")
raise ValueError("Weave login failed")
@ -386,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance):
def api_check(self):
try:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if self.host:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
else:
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
if not login_status:
raise ValueError("Weave login failed")
else:

View File

@ -11,14 +11,12 @@ class BaseBackwardsInvocation:
try:
for chunk in response:
if isinstance(chunk, BaseModel | dict):
yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n"
elif isinstance(chunk, str):
yield f"event: {chunk}\n\n".encode()
yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode()
except Exception as e:
error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json()
yield f"{error_message}\n\n".encode()
yield error_message.encode()
else:
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() + b"\n\n"
yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode()
T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel)
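
The dropped "\n\n" separators line up with the switch to length_prefixed_response(0xF, ...) in the plugin endpoints earlier in this diff, which presumably frames each chunk at the transport layer instead. A generic sketch of length-prefixed framing (the header layout here is an assumption, not Dify's confirmed wire format):

import struct
from collections.abc import Iterator

def frame_chunks(tag: int, chunks: Iterator[bytes]) -> Iterator[bytes]:
    # Hypothetical framing: each payload is preceded by a 1-byte tag and a
    # 4-byte big-endian length, so consumers can split the stream without
    # relying on "\n\n" delimiters between JSON payloads.
    for chunk in chunks:
        yield struct.pack(">BI", tag, len(chunk)) + chunk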

View File

@ -21,7 +21,7 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
LLMNode.deduct_llm_quota(
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
return handle()
else:
if response.usage:
LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(

View File

@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
# TODO: Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict
@staticmethod
@ -416,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all()
)
provider_name_to_provider_model_records_dict = defaultdict(list)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict
@staticmethod
@ -437,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
preferred_provider_types = (
db.session.query(TenantPreferredModelProvider)
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
.all()
)
provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}
return provider_name_to_preferred_provider_type_records_dict
@staticmethod
@ -458,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
provider_model_settings = (
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
)
provider_name_to_provider_model_settings_dict = defaultdict(list)
for provider_model_setting in provider_model_settings:
(
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting
)
)
return provider_name_to_provider_model_settings_dict
@staticmethod
@ -492,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled:
return {}
provider_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
)
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict
@ -626,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials:
try:
# fix origin data
if (
custom_provider_record.encrypted_config
and not custom_provider_record.encrypted_config.startswith("{")
):
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@ -733,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False)
# Convert provider_records to dict
quota_type_to_provider_records_dict = {}
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
continue
@ -758,6 +740,11 @@ class ProviderManager:
else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
if provider_record.quota_used is None:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@ -791,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials:
try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(

View File

@ -720,7 +720,7 @@ STOPWORDS = {
"",
"",
"",
" ",
" ",
"0",
"1",
"2",
@ -731,16 +731,6 @@ STOPWORDS = {
"7",
"8",
"9",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",

View File

@ -184,7 +184,16 @@ class OpenSearchVector(BaseVector):
}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
query["query"] = {
"script_score": {
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
"script": {
"source": "knn_score",
"lang": "knn",
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
},
}
}
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
@ -209,10 +218,10 @@ class OpenSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}]
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
@ -255,7 +264,8 @@ class OpenSearchVector(BaseVector):
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
"document_id": {"type": "keyword"},
},
},
}

View File

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名,nt: 机构名
current_entity += word
else:
if current_entity:
@ -303,7 +303,6 @@ class OracleVector(BaseVector):
return docs
else:
return [Document(page_content="", metadata={})]
return []
def delete(self) -> None:
with self._get_connection() as conn:

View File

@ -139,4 +139,4 @@ class CacheEmbedding(Embeddings):
logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'")
raise ex
return embedding_results
return embedding_results # type: ignore

View File

@ -105,7 +105,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type
if not file.filename.endswith(".csv"):
if not file.filename or not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:

View File

@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@ -1,3 +1,4 @@
- audio
- code
- time
- qrcode
- webscraper

View File

@ -153,8 +153,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join(document_context_list))
return ""
raise RuntimeError("not segments found")
def _retriever(
self,
flask_app: Flask,

View File

@ -32,14 +32,14 @@ class ToolFileMessageTransformer:
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_url(
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@ -68,7 +68,7 @@ class ToolFileMessageTransformer:
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_raw(
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
@ -77,7 +77,7 @@ class ToolFileMessageTransformer:
filename=filename,
)
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype))
# check if file is image
if "image" in mimetype:

View File

@ -397,19 +397,44 @@ def _extract_text_from_csv(file_content: bytes) -> str:
if not rows:
return ""
# Create Markdown table
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
for row in rows[1:]:
markdown_table += "| " + " | ".join(row) + " |\n"
# Combine multi-line text in the header row
header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]]
return markdown_table.strip()
# Create Markdown table
markdown_table = "| " + " | ".join(header_row) + " |\n"
markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n"
# Process each data row and combine multi-line text in each cell
for row in rows[1:]:
processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row]
markdown_table += "| " + " | ".join(processed_row) + " |\n"
return markdown_table
except Exception as e:
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
def _extract_text_from_excel(file_content: bytes) -> str:
"""Extract text from an Excel file using pandas."""
def _construct_markdown_table(df: pd.DataFrame) -> str:
"""Manually construct a Markdown table from a DataFrame."""
# Construct the header row
header_row = "| " + " | ".join(df.columns) + " |"
# Construct the separator row
separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"
# Construct the data rows
data_rows = []
for _, row in df.iterrows():
data_row = "| " + " | ".join(map(str, row)) + " |"
data_rows.append(data_row)
# Combine all rows into a single string
markdown_table = "\n".join([header_row, separator_row] + data_rows)
return markdown_table
try:
excel_file = pd.ExcelFile(io.BytesIO(file_content))
markdown_table = ""
@ -417,8 +442,15 @@ def _extract_text_from_excel(file_content: bytes) -> str:
try:
df = excel_file.parse(sheet_name=sheet_name)
df.dropna(how="all", inplace=True)
# Create Markdown table two times to separate tables with a newline
markdown_table += df.to_markdown(index=False, floatfmt="") + "\n\n"
# Combine multi-line text in each cell into a single line
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
# Combine multi-line text in column names into a single line
df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns])
# Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n"
except Exception as e:
continue
return markdown_table
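
To illustrate the output shape of the nested _construct_markdown_table helper above, a small sketch with hypothetical data (the separator dashes match each header's length):

import pandas as pd

df = pd.DataFrame({"name": ["apple"], "unit price": [1.5]})
# _construct_markdown_table(df) would produce:
#   | name | unit price |
#   | ---- | ---------- |
#   | apple | 1.5 |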

View File

@ -1,8 +1,9 @@
import base64
import json
import secrets
import string
from collections.abc import Mapping
from copy import deepcopy
from random import randint
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
@ -434,4 +435,4 @@ def _generate_random_string(n: int) -> str:
>>> _generate_random_string(5)
'abcde'
"""
return "".join([chr(randint(97, 122)) for _ in range(n)])
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n))

View File

@ -128,3 +128,12 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property
def structured_output_enabled(self) -> bool:
# NOTE(QuantumGhost): Temporary workaround for issue #20725
# (https://github.com/langgenius/dify/issues/20725).
#
# The proper fix would be to make `KnowledgeRetrievalNode` inherit
# from `BaseNode` instead of `LLMNode`.
return False

View File

@ -86,31 +86,31 @@ class KnowledgeRetrievalNode(LLMNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
)
# TODO(-LAN-): Move this check outside.
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
try:
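
The rate-limit block above is a sliding-window counter on a Redis sorted set. A standalone sketch of the same pattern, assuming a configured `redis_client` (the helper name is illustrative):

import time

def within_rate_limit(redis_client, tenant_id: str, limit: int, window_ms: int = 60_000) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_client.zadd(key, {now_ms: now_ms})  # record this request with its timestamp as score
    redis_client.zremrangebyscore(key, 0, now_ms - window_ms)  # evict entries older than the window
    return redis_client.zcard(key) <= limit  # remaining members are requests inside the window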

View File

@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output_enabled: bool = False
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before")
@classmethod
@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None:
return PromptConfig()
return v
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None
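
A minimal sketch of the rename-via-alias pattern used here (Pydantic v2): old payloads keep sending `structured_output_enabled`, while code reads the new attribute name.

from pydantic import BaseModel, Field

class Example(BaseModel):
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

node = Example.model_validate({"structured_output_enabled": True})
assert node.structured_output_switch_on is True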

View File

@ -0,0 +1,156 @@
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional, cast
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from models import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
def fetch_model_config(
tenant_id: str, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
model = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
)
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
# model config
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
variable = variable_pool.get(selector)
if variable is None:
return []
elif isinstance(variable, FileSegment):
return [variable.value]
elif isinstance(variable, ArrayFileSegment):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
if not node_data_memory:
return None
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()

View File

@ -3,18 +3,11 @@ import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@ -42,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
ArraySegment,
FileSegment,
@ -74,14 +65,11 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from . import llm_utils
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
@ -91,7 +79,6 @@ from .entities import (
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError,
@ -163,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
variable_pool = self.graph_runtime_state.variable_pool
try:
# init messages template
@ -181,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch files
files = (
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector,
)
if self.node_data.vision.enabled
else []
)
@ -203,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
model_instance, model_config = self._fetch_model_config(self.node_data.model)
# fetch memory
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)
query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := self.graph_runtime_state.variable_pool.get(
(SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
)
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text
@ -225,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)
@ -254,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event.usage
finish_reason = event.finish_reason
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
structured_output = process_structured_output(result_text)
@ -277,7 +271,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage,
)
)
except LLMNodeError as e:
except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -450,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
return inputs
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
elif isinstance(variable, FileSegment):
return [variable.value]
elif isinstance(variable, ArrayFileSegment):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return
@ -527,91 +509,23 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = node_data_model.name
provider_name = node_data_model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = node_data_model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
if not node_data_memory:
return None
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
if self.node_data.structured_output_enabled:
if model_schema.support_structure_output:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
else:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
model_config_with_cred.parameters = completion_params
return model, model_config_with_cred
def _fetch_prompt_messages(
self,
@ -786,13 +700,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop
return filtered_prompt_messages, stop
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
)
model_schema = model.model_type_instance.get_model_schema(
model=model_config.model,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {model_config.model} not exist.")
if self.node_data.structured_output_enabled:
if not model_schema.support_structure_output:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
@ -813,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = parsed
return structured_output
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -903,7 +780,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
@ -1185,32 +1062,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.
Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,

View File

@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from .entities import ParameterExtractorNodeData
@ -83,7 +84,7 @@ def extract_json(text):
return None
class ParameterExtractorNode(LLMNode):
class ParameterExtractorNode(BaseNode):
"""
Parameter Extractor Node.
"""
@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
variable_pool = self.graph_runtime_state.variable_pool
files = (
self._fetch_files(
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
raise ModelSchemaNotFoundError("Model schema not found")
# fetch memory
memory = self._fetch_memory(
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
Fetch model config.
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
self._model_instance, self._model_config = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
)
return self._model_instance, self._model_config

View File

@ -19,3 +19,12 @@ class QuestionClassifierNodeData(BaseNodeData):
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
vision: VisionConfig = Field(default_factory=VisionConfig)
@property
def structured_output_enabled(self) -> bool:
# NOTE(QuantumGhost): Temporary workaround for issue #20725
# (https://github.com/langgenius/dify/issues/20725).
#
# The proper fix would be to make `QuestionClassifierNode` inherit
# from `BaseNode` instead of `LLMNode`.
return False

View File

@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files = (
self._fetch_files(
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled

View File

@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini"
OLLAMA = "ollama"
class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"

View File

@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
"schedule.update_tidb_serverless_status_task",
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
"schedule.queue_monitor_task",
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
"datasets-queue-monitor": {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
},
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

View File

@ -57,6 +57,9 @@ def load_user_from_request(request_from_flask_login):
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id")
source = decoded.get("token_source")
if source:
raise Unauthorized("Invalid Authorization token.")
if not user_id:
raise Unauthorized("Invalid Authorization token.")

View File

@ -1,8 +1,9 @@
import json
import logging
import random
import re
import secrets
import string
import struct
import subprocess
import time
import uuid
@ -14,10 +15,12 @@ from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields
from pydantic import BaseModel
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
@ -175,14 +178,14 @@ def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""
for i in range(n):
result += random.choice(letters_digits)
result += secrets.choice(letters_digits)
return result
def extract_remote_ip(request) -> str:
if request.headers.get("CF-Connecting-IP"):
return cast(str, request.headers.get("Cf-Connecting-Ip"))
return cast(str, request.headers.get("CF-Connecting-IP"))
elif request.headers.getlist("X-Forwarded-For"):
return cast(str, request.headers.getlist("X-Forwarded-For")[0])
else:
@ -196,7 +199,7 @@ def generate_text_hash(text: str) -> str:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype="application/json")
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
else:
def generate() -> Generator:
@ -205,6 +208,60 @@ def compact_generate_response(response: Union[Mapping, Generator, RateLimitGener
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
"""
    This function is used to return a response with a length prefix.
    The magic number is a one-byte value that indicates the type of the response.
    For compatibility with the latest plugin daemon (https://github.com/langgenius/dify-plugin-daemon/pull/341),
    avoid line-based responses; they lead to memory issues.
    We use the following format:
| Field | Size | Description |
|---------------|----------|---------------------------------|
| Magic Number | 1 byte | Magic number identifier |
| Reserved | 1 byte | Reserved field |
| Header Length | 2 bytes | Header length (usually 0xa) |
| Data Length | 4 bytes | Length of the data |
| Reserved | 6 bytes | Reserved fields |
| Data | Variable | Actual data content |
    | Fixed header   | Data     |
    |----------------|----------|
    | 14 bytes total | Variable |
    All multi-byte fields are little-endian.
"""
def pack_response_with_length_prefix(response: bytes) -> bytes:
header_length = 0xA
data_length = len(response)
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, dict):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
mimetype="application/json",
)
elif isinstance(response, BaseModel):
return Response(
response=pack_response_with_length_prefix(response.model_dump_json().encode("utf-8")),
status=200,
mimetype="application/json",
)
def generate() -> Generator:
for chunk in response:
if isinstance(chunk, str):
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
else:
yield pack_response_with_length_prefix(chunk)
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
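
For reference, a hedged client-side sketch that unpacks a single frame in the documented format (fixed 14-byte prefix; stream buffering omitted):

import struct

def read_frame(buf: bytes) -> tuple[int, bytes]:
    # | magic 1B | reserved 1B | header_len 2B | data_len 4B | reserved 6B | data |
    magic, _reserved, _header_len, data_len = struct.unpack("<BBHI", buf[:8])
    return magic, buf[14 : 14 + data_len]  # skip the 8-byte fixed part plus 6 reserved bytes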
class TokenManager:
@classmethod
def generate_token(

View File

@ -0,0 +1,60 @@
"""`workflow_draft_varaibles` add `node_execution_id` column, add an index for `workflow_node_executions`.
Revision ID: 4474872b0ee6
Revises: 2adcbe1f5dfb
Create Date: 2025-06-06 14:24:44.213018
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4474872b0ee6'
down_revision = '2adcbe1f5dfb'
branch_labels = None
depends_on = None
def upgrade():
# `CREATE INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
# context manager to wrap the index creation statement.
# Reference:
#
# - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
# - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
with op.get_context().autocommit_block():
op.create_index(
op.f('workflow_node_executions_tenant_id_idx'),
"workflow_node_executions",
['tenant_id', 'workflow_id', 'node_id', sa.literal_column('created_at DESC')],
unique=False,
postgresql_concurrently=True,
)
with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
batch_op.add_column(sa.Column('node_execution_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
    # `DROP INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
    # context manager to wrap the index drop statement.
    # Reference:
    #
    # - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
    # - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
with op.get_context().autocommit_block():
op.drop_index(op.f('workflow_node_executions_tenant_id_idx'), postgresql_concurrently=True)
with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
batch_op.drop_column('node_execution_id')
# ### end Alembic commands ###

View File

@ -1,6 +1,9 @@
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import func
from sqlalchemy import func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
@ -51,20 +54,24 @@ class Provider(Base):
),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
last_used = db.Column(db.DateTime, nullable=True)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'custom'::character varying")
)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
quota_limit = db.Column(db.BigInteger, nullable=True)
quota_used = db.Column(db.BigInteger, default=0)
quota_type: Mapped[Optional[str]] = mapped_column(
db.String(40), nullable=True, server_default=text("''::character varying")
)
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self):
return (
@ -104,15 +111,15 @@ class ProviderModel(Base):
),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base):
@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
@ -153,22 +160,24 @@ class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191))
transaction_id = db.Column(db.String(191))
quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
currency = db.Column(db.String(40))
total_amount = db.Column(db.Integer)
payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
paid_at = db.Column(db.DateTime)
pay_failed_at = db.Column(db.DateTime)
refunded_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(db.String(40))
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
payment_status: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
)
paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
name = db.Column(db.String(255), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
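
A compact sketch of the SQLAlchemy 2.0 typed-column style these models are migrating to (standalone `Base` assumed for illustration): `Mapped[Optional[...]]` marks the column nullable, mirroring the explicit `nullable=` flags above.

from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Example(Base):
    __tablename__ = "example"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    note: Mapped[Optional[str]] = mapped_column(String(255))  # Optional => nullable
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())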

View File

@ -16,8 +16,8 @@ if TYPE_CHECKING:
from models.model import AppMode
import sqlalchemy as sa
from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
@ -615,28 +615,48 @@ class WorkflowNodeExecutionModel(Base):
"""
__tablename__ = "workflow_node_executions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
db.Index(
"workflow_node_execution_workflow_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"workflow_run_id",
),
db.Index(
"workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id"
),
db.Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
)
@declared_attr
def __table_args__(cls): # noqa
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
                # which we leave as `None` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore
),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
@ -910,14 +930,29 @@ class WorkflowDraftVariable(Base):
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
# The data type of this variable's value
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# JSON string
# The variable's value serialized as a JSON string
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
# visible
# Controls whether the variable should be displayed in the variable inspection panel
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
# Determines whether this variable can be modified by users
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
# The `node_execution_id` field identifies the workflow node execution that created this variable.
# It corresponds to the `id` field in the `WorkflowNodeExecutionModel` model.
#
# This field is not `None` for system variables and node variables, and is `None`
# for conversation variables.
node_execution_id: Mapped[str | None] = mapped_column(
StringUUID,
nullable=True,
default=None,
)
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
if not isinstance(selector, list):

View File

@ -2,6 +2,8 @@
warn_return_any = True
warn_unused_configs = True
check_untyped_defs = True
cache_fine_grained = True
sqlite_cache = True
exclude = (?x)(
core/model_runtime/model_providers/
| tests/

View File

@ -56,7 +56,6 @@ dependencies = [
"opentelemetry-sdk==1.27.0",
"opentelemetry-semantic-conventions==0.48b0",
"opentelemetry-util-http==0.48b0",
"pandas-stubs~=2.2.3.241009",
"pandas[excel,output-formatting,performance]~=2.2.2",
"pandoc~=2.4",
"psycogreen~=1.0.2",
@ -104,7 +103,7 @@ dev = [
"dotenv-linter~=0.5.0",
"faker~=32.1.0",
"lxml-stubs~=0.5.1",
"mypy~=1.15.0",
"mypy~=1.16.0",
"ruff~=0.11.5",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
@ -152,6 +151,8 @@ dev = [
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0",
]
############################################################

View File

@ -0,0 +1,62 @@
import logging
from datetime import datetime
from urllib.parse import urlparse
import click
from flask import render_template
from redis import Redis
import app
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_mail import mail
# Create a dedicated Redis connection (using the same configuration as Celery)
celery_broker_url = dify_config.CELERY_BROKER_URL
parsed = urlparse(celery_broker_url)
host = parsed.hostname or "localhost"
port = parsed.port or 6379
password = parsed.password or None
redis_db = parsed.path.strip("/") or "1" # type: ignore
celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
@app.celery.task(queue="monitor")
def queue_monitor_task():
queue_name = "dataset"
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
try:
        queue_length = celery_redis.llen(queue_name)
        logging.info(click.style(f"Start monitoring queue {queue_name}", fg="green"))
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
if queue_length >= threshold:
warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
logging.warning(click.style(warning_msg, fg="red"))
            alert_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
            if alert_emails:
                to_list = alert_emails.split(",")
for to in to_list:
try:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
html_content = render_template(
"queue_monitor_alert_email_template_en-US.html",
queue_name=queue_name,
queue_length=queue_length,
threshold=threshold,
alert_time=current_time,
)
mail.send(
to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
)
                    except Exception:
                        logging.exception(click.style("Exception occurred during sending email", fg="red"))
    except Exception:
        logging.exception(click.style("Exception occurred during queue monitoring", fg="red"))
finally:
if db.session.is_active:
db.session.close()
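
Usage note (sketch): the same backlog check the task performs can be run ad hoc against the broker, and the task itself can be fired outside the beat schedule, assuming a worker consumes the `monitor` queue.

pending = celery_redis.llen("dataset")
print(f"dataset queue backlog: {pending}")
queue_monitor_task.delay()  # routed to the "monitor" queue by the task decorator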

View File

@ -1,7 +1,6 @@
import base64
import json
import logging
import random
import secrets
import uuid
from datetime import UTC, datetime, timedelta
@ -261,7 +260,7 @@ class AccountService:
@staticmethod
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, token_type="account_deletion", additional_data={"code": code}
)
@ -429,7 +428,7 @@ class AccountService:
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data=additional_data
@ -456,7 +455,7 @@ class AccountService:
raise EmailCodeLoginRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
)
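
The pattern above, as a minimal standalone sketch: six independent draws from a CSPRNG yield a uniform 6-digit code (leading zeros allowed).

import secrets

code = "".join(str(secrets.randbelow(10)) for _ in range(6))
assert len(code) == 6 and code.isdigit()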

View File

@ -395,3 +395,15 @@ class AppService:
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)
@staticmethod
def get_app_id_by_code(app_code: str) -> str:
"""
Get app id by app code
:param app_code: app code
:return: app id
"""
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)

View File

@ -2,7 +2,7 @@ import copy
import datetime
import json
import logging
import random
import secrets
import time
import uuid
from collections import Counter
@ -1140,7 +1140,7 @@ class DocumentService:
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
# save process rule
if not dataset_process_rule:
process_rule = knowledge_config.process_rule

View File

@ -1,3 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest
@ -5,7 +7,7 @@ from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'",
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
default="private",
alias="accessMode",
)
@ -20,6 +22,28 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
@classmethod
def get_workspace_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time")
if not data:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):

View File

@ -46,6 +46,8 @@ class TagService:
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
@ -88,7 +90,7 @@ class TagService:
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:

View File

@ -1,31 +1,38 @@
import random
import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from controllers.web.error import WebAppAuthAccessDeniedError
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task
class WebAppAuthType(enum.StrEnum):
"""Enum for web app authentication types."""
PUBLIC = "public"
INTERNAL = "internal"
EXTERNAL = "external"
class WebAppAuthService:
"""Service for web app authentication."""
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
@ -38,12 +45,8 @@ class WebAppAuthService:
return cast(Account, account)
@classmethod
def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
def login(cls, account: Account) -> str:
access_token = cls._get_account_jwt_token(account=account)
return access_token
@ -66,9 +69,9 @@ class WebAppAuthService:
if email is None:
raise ValueError("Email must be provided.")
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
@ -80,11 +83,11 @@ class WebAppAuthService:
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "webapp_email_code_login")
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "webapp_email_code_login")
TokenManager.revoke_token(token, "email_code_login")
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
@ -109,33 +112,67 @@ class WebAppAuthService:
return end_user
@classmethod
def _validate_user_accessibility(cls, account: Account, app_code: str):
"""Check if the user is allowed to access the app."""
system_features = FeatureService.get_system_features()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
if (
app_settings.access_mode != "public"
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
):
raise WebAppAuthAccessDeniedError()
@classmethod
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
def _get_account_jwt_token(cls, account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id,
"end_user_id": end_user_id,
"token_source": "webapp",
"session_id": account.email,
"token_source": "webapp_login_token",
"auth_type": "internal",
"exp": exp,
}
token: str = PassportService().issue(payload)
return token
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.
"""
modes_requiring_permission_check = [
"private",
"private_all",
]
if access_mode:
return access_mode in modes_requiring_permission_check
if not app_code and not app_id:
raise ValueError("Either app_code or app_id must be provided.")
if app_code:
app_id = AppService.get_app_id_by_code(app_code)
if not app_id:
raise ValueError("App ID could not be determined from the provided app_code.")
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
return True
return False
@classmethod
def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType:
"""
Get the authentication type for the app based on its access mode.
"""
if not app_code and not access_mode:
raise ValueError("Either app_code or access_mode must be provided.")
if access_mode:
if access_mode == "public":
return WebAppAuthType.PUBLIC
elif access_mode in ["private", "private_all"]:
return WebAppAuthType.INTERNAL
elif access_mode == "sso_verified":
return WebAppAuthType.EXTERNAL
if app_code:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
raise ValueError("Could not determine app authentication type.")

View File

@ -5,7 +5,7 @@ import uuid
import click
from celery import shared_task # type: ignore
from sqlalchemy import func, select
from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
@ -68,11 +68,6 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
segments_to_insert: list[str] = []
max_position_stmt = select(func.max(DocumentSegment.position)).where(
DocumentSegment.document_id == dataset_document.id
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(

View File

@ -0,0 +1,129 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
min-height: 605px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.alert-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #fef0f0;
margin: 16px auto;
border: 1px solid #fda29b;
}
.alert-title {
line-height: 24px;
font-weight: 700;
font-size: 18px;
color: #d92d20;
}
.alert-detail {
line-height: 20px;
font-size: 14px;
margin-top: 8px;
}
.typography {
letter-spacing: -0.07px;
font-weight: 400;
font-style: normal;
font-size: 14px;
line-height: 20px;
color: #354052;
margin-top: 12px;
margin-bottom: 12px;
}
.typography p{
margin: 0 auto;
}
.typography-title {
color: #101828;
font-size: 14px;
font-style: normal;
font-weight: 600;
line-height: 20px;
margin-top: 12px;
margin-bottom: 4px;
}
.tip-list{
margin: 0;
padding-left: 10px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<img src="https://assets.dify.ai/images/logo.png" alt="Dify Logo" />
</div>
<p class="title">Queue Monitoring Alert</p>
<p class="typography">Our system has detected an abnormal queue status that requires your attention:</p>
<div class="alert-content">
<div class="alert-title">Queue Task Alert</div>
<div class="alert-detail">
Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})
</div>
</div>
<div class="typography">
<p style="margin-bottom:4px">Recommended actions:</p>
<p>1. Check the queue processing status in the system dashboard</p>
<p>2. Verify if there are any processing bottlenecks</p>
<p>3. Consider scaling up workers if needed</p>
</div>
<p class="typography-title">Additional Information:</p>
<ul class="typography tip-list">
<li>Alert triggered at: {{alert_time}}</li>
</ul>
</div>
</body>
</html>
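
The `{{...}}` placeholders suggest this template is rendered server-side before mailing; a minimal Jinja2 sketch, with the filename and all values hypothetical:

```python
from pathlib import Path
from jinja2 import Template

html = Template(Path("queue_monitor_alert_email_template.html").read_text())
body = html.render(
    queue_name="dataset",   # illustrative values only
    queue_length=512,
    threshold=200,
    alert_time="2025-06-11 17:16 UTC",
)
```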

View File

@ -3,11 +3,16 @@ import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.fixture(scope="session")
def app():
# Set up storage configuration
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
# Ensure storage directory exists
os.makedirs("storage", exist_ok=True)
app = create_app()
dify_config.LOGIN_DISABLED = True
return app
def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
graph = Graph.init(graph_config=graph_config)
# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
tenant_id=tenant_id,
app_id=app_id,
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
workflow_id=workflow_id,
graph_config=graph_config,
user_id="1",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
return node
def test_execute_llm(setup_model_mock):
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
def test_execute_llm(app):
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
},
)
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials=credentials,
)
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# execute node
result = node._run()
assert isinstance(result, Generator)
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
"""
Test execute LLM node with jinja2
"""
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
},
)
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close()
db.session.close = MagicMock()
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials=credentials,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
# execute node
result = node._run()
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json():
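
The rewrite above replaces the `setup_model_mock` fixture with inline patching: one `patch.object` on the node instance plus one dotted-path `patch`. The same mechanics in a self-contained toy example (stdlib target stands in for `ModelManager.get_model_instance`):

```python
import os
from unittest.mock import patch

class Node:
    def _fetch_model_config(self):
        return "real-config"

node = Node()
with (
    patch.object(node, "_fetch_model_config", lambda: "mock-config"),  # per-instance patch
    patch("os.getcwd", lambda: "/tmp"),  # dotted-path module patch (stand-in target)
):
    assert node._fetch_model_config() == "mock-config"
    assert os.getcwd() == "/tmp"
```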

View File

@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor with memory.
"""
@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
node._fetch_memory = get_mocked_fetch_memory("customized memory")
# Patch the module-level fetch_memory before running the node
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()
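
The `monkeypatch.setattr` call above swaps the module-level `fetch_memory` rather than a method on the node; the dotted-path mechanics, in a self-contained sketch:

```python
def test_dotted_path_patch(monkeypatch):
    import json
    # Replace a module attribute for the duration of this test only;
    # pytest restores the original on teardown.
    monkeypatch.setattr("json.dumps", lambda *args, **kwargs: "patched")
    assert json.dumps({"any": "thing"}) == "patched"
```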

View File

@ -1,4 +1,4 @@
import random
import secrets
from unittest.mock import MagicMock, patch
import pytest
@ -34,7 +34,7 @@ def test_retry_logic_success(mock_request):
side_effects = []
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
status_code = random.choice(STATUS_FORCELIST)
status_code = secrets.choice(STATUS_FORCELIST)
mock_response = MagicMock()
mock_response.status_code = status_code
side_effects.append(mock_response)
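
`secrets.choice` is a drop-in replacement for `random.choice` here: the same sequence-sampling API, but backed by the OS CSPRNG, which satisfies linters that flag the `random` module. A sketch with assumed forcelist values:

```python
import secrets

STATUS_FORCELIST = [429, 500, 502, 503, 504]  # assumed retryable status codes
status_code = secrets.choice(STATUS_FORCELIST)
assert status_code in STATUS_FORCELIST
```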

View File

@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@ -170,7 +171,7 @@ def model_config():
)
def test_fetch_files_with_file_segment(llm_node):
def test_fetch_files_with_file_segment():
file = File(
id="1",
tenant_id="test",
@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], file)
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == [file]
def test_fetch_files_with_array_file_segment(llm_node):
def test_fetch_files_with_array_file_segment():
files = [
File(
id="1",
@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == files
def test_fetch_files_with_none_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
def test_fetch_files_with_none_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
def test_fetch_files_with_array_any_segment(llm_node):
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
def test_fetch_files_with_array_any_segment():
variable_pool = VariablePool()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_node._fetch_files(selector=["sys", "files"])
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
def test_fetch_files_with_non_existent_variable(llm_node):
result = llm_node._fetch_files(selector=["sys", "files"])
def test_fetch_files_with_non_existent_variable():
variable_pool = VariablePool()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
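
The fixture-free rewrite boils down to one pattern: a fresh `VariablePool` per test, passed explicitly to the now-module-level `llm_utils.fetch_files`. In miniature, using the imports shown in the diff (runnable inside the repo):

```python
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm import llm_utils

variable_pool = VariablePool()
# Nothing registered under ["sys", "files"] yet, so the lookup comes back empty.
assert llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) == []
```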

View File

@ -1,5 +1,7 @@
import io
from unittest.mock import Mock, patch
import pandas as pd
import pytest
from docx.oxml.text.paragraph import CT_P
@ -187,145 +189,134 @@ def test_node_type(document_extractor_node):
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_single_sheet(mock_excel_file):
"""Test extracting text from Excel file with single sheet."""
# Mock DataFrame
mock_df = Mock()
mock_df.dropna = Mock()
mock_df.to_markdown.return_value = "| Name | Age |\n|------|-----|\n| John | 25 |"
"""Test extracting text from Excel file with single sheet and multiline content."""
# Test multi-line cell
data = {"Name\nwith\nnewline": ["John\nDoe", "Jane\nSmith"], "Age": [25, 30]}
df = pd.DataFrame(data)
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["Sheet1"]
mock_excel_instance.parse.return_value = mock_df
mock_excel_instance.parse.return_value = df
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_content"
result = _extract_text_from_excel(file_content)
expected_manual = "| Name with newline | Age |\n| ----------------- | --- |\n\
| John Doe | 25 |\n| Jane Smith | 30 |\n\n"
expected = "| Name | Age |\n|------|-----|\n| John | 25 |\n\n"
assert result == expected
mock_excel_file.assert_called_once()
mock_df.dropna.assert_called_once_with(how="all", inplace=True)
mock_df.to_markdown.assert_called_once_with(index=False, floatfmt="")
assert expected_manual == result
mock_excel_instance.parse.assert_called_once_with(sheet_name="Sheet1")
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_multiple_sheets(mock_excel_file):
"""Test extracting text from Excel file with multiple sheets."""
# Mock DataFrames for different sheets
mock_df1 = Mock()
mock_df1.dropna = Mock()
mock_df1.to_markdown.return_value = "| Product | Price |\n|---------|-------|\n| Apple | 1.50 |"
"""Test extracting text from Excel file with multiple sheets and multiline content."""
mock_df2 = Mock()
mock_df2.dropna = Mock()
mock_df2.to_markdown.return_value = "| City | Population |\n|------|------------|\n| NYC | 8000000 |"
# Test multi-line cell
data1 = {"Product\nName": ["Apple\nRed", "Banana\nYellow"], "Price": [1.50, 0.99]}
df1 = pd.DataFrame(data1)
data2 = {"City\nName": ["New\nYork", "Los\nAngeles"], "Population": [8000000, 3900000]}
df2 = pd.DataFrame(data2)
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["Products", "Cities"]
mock_excel_instance.parse.side_effect = [mock_df1, mock_df2]
mock_excel_instance.parse.side_effect = [df1, df2]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_content_multiple_sheets"
result = _extract_text_from_excel(file_content)
expected = (
"| Product | Price |\n|---------|-------|\n| Apple | 1.50 |\n\n"
"| City | Population |\n|------|------------|\n| NYC | 8000000 |\n\n"
)
assert result == expected
expected_manual1 = "| Product Name | Price |\n| ------------ | ----- |\n\
| Apple Red | 1.5 |\n| Banana Yellow | 0.99 |\n\n"
expected_manual2 = "| City Name | Population |\n| --------- | ---------- |\n\
| New York | 8000000 |\n| Los Angeles | 3900000 |\n\n"
assert expected_manual1 in result
assert expected_manual2 in result
assert mock_excel_instance.parse.call_count == 2
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_empty_sheets(mock_excel_file):
"""Test extracting text from Excel file with empty sheets."""
# Mock empty DataFrame
mock_df = Mock()
mock_df.dropna = Mock()
mock_df.to_markdown.return_value = ""
# Empty excel
df = pd.DataFrame()
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["EmptySheet"]
mock_excel_instance.parse.return_value = mock_df
mock_excel_instance.parse.return_value = df
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_empty_content"
result = _extract_text_from_excel(file_content)
expected = "\n\n"
expected = "| |\n| |\n\n"
assert result == expected
mock_excel_instance.parse.assert_called_once_with(sheet_name="EmptySheet")
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
"""Test handling of sheet parsing errors - should continue with other sheets."""
# Mock DataFrames - one successful, one that raises exception
mock_df_success = Mock()
mock_df_success.dropna = Mock()
mock_df_success.to_markdown.return_value = "| Data | Value |\n|------|-------|\n| Test | 123 |"
# Test error
data = {"Data": ["Test"], "Value": [123]}
df = pd.DataFrame(data)
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
mock_excel_instance.parse.side_effect = [mock_df_success, Exception("Parse error")]
mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_mixed_content"
result = _extract_text_from_excel(file_content)
expected = "| Data | Value |\n|------|-------|\n| Test | 123 |\n\n"
assert result == expected
expected_manual = "| Data | Value |\n| ---- | ----- |\n| Test | 123 |\n\n"
assert expected_manual == result
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_file_error(mock_excel_file):
"""Test handling of Excel file reading errors."""
mock_excel_file.side_effect = Exception("Invalid Excel file")
file_content = b"invalid_excel_content"
with pytest.raises(Exception) as exc_info:
_extract_text_from_excel(file_content)
# Note: The function should raise TextExtractionError, but since it's not imported in the test,
# we check for the general Exception pattern
assert "Failed to extract text from Excel file" in str(exc_info.value)
assert mock_excel_instance.parse.call_count == 2
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_io_bytesio_usage(mock_excel_file):
"""Test that BytesIO is properly used with the file content."""
import io
# Mock DataFrame
mock_df = Mock()
mock_df.dropna = Mock()
mock_df.to_markdown.return_value = "| Test | Data |\n|------|------|\n| 1 | A |"
# Test bytesio
data = {"Test": [1], "Data": ["A"]}
df = pd.DataFrame(data)
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["TestSheet"]
mock_excel_instance.parse.return_value = mock_df
mock_excel_instance.parse.return_value = df
mock_excel_file.return_value = mock_excel_instance
file_content = b"test_excel_bytes"
result = _extract_text_from_excel(file_content)
# Verify that ExcelFile was called with a BytesIO object
mock_excel_file.assert_called_once()
call_args = mock_excel_file.call_args[0][0]
assert isinstance(call_args, io.BytesIO)
call_arg = mock_excel_file.call_args[0][0]
assert isinstance(call_arg, io.BytesIO)
expected = "| Test | Data |\n|------|------|\n| 1 | A |\n\n"
assert result == expected
expected_manual = "| Test | Data |\n| ---- | ---- |\n| 1 | A |\n\n"
assert expected_manual == result
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
"""Test when all sheets fail to parse - should return empty string."""
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
@ -335,29 +326,6 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
file_content = b"fake_excel_all_bad_sheets"
result = _extract_text_from_excel(file_content)
# Should return empty string when all sheets fail
assert result == ""
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_markdown_formatting(mock_excel_file):
"""Test that markdown formatting parameters are correctly applied."""
# Mock DataFrame
mock_df = Mock()
mock_df.dropna = Mock()
mock_df.to_markdown.return_value = "| Float | Int |\n|-------|-----|\n| 123456.78 | 42 |"
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["NumberSheet"]
mock_excel_instance.parse.return_value = mock_df
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_numbers"
result = _extract_text_from_excel(file_content)
# Verify to_markdown was called with correct parameters
mock_df.to_markdown.assert_called_once_with(index=False, floatfmt="")
expected = "| Float | Int |\n|-------|-----|\n| 123456.78 | 42 |\n\n"
assert result == expected
assert mock_excel_instance.parse.call_count == 2
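
The rewritten tests feed real DataFrames through the extractor instead of stubbing `to_markdown`. The newline flattening they assert (`Name\nwith\nnewline` becoming `Name with newline`) implies the extractor normalizes cells and headers before rendering; roughly, as a sketch rather than the shipped implementation (`to_markdown` requires `tabulate`):

```python
import pandas as pd

df = pd.DataFrame({"Name\nwith\nnewline": ["John\nDoe", "Jane\nSmith"], "Age": [25, 30]})
df = df.replace("\n", " ", regex=True)                        # flatten newlines in cells
df.columns = [str(c).replace("\n", " ") for c in df.columns]  # flatten newlines in headers
print(df.to_markdown(index=False, floatfmt=""))
```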

File diff suppressed because it is too large

View File

@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
# run mypy checks
uv run --directory api --dev --with pip \
python -m mypy --install-types --non-interactive --cache-fine-grained --sqlite-cache .
python -m mypy --install-types --non-interactive ./

View File

@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600
PIP_MIRROR_URL=
# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example
# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss
# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos
PLUGIN_STORAGE_TYPE=local
PLUGIN_STORAGE_LOCAL_ROOT=/app/storage
PLUGIN_WORKING_PATH=/app/storage/cwd
@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
# Plugin OSS: Volcengine TOS
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------
# OTLP Collector Configuration
@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
# Dataset queue monitor configuration
QUEUE_MONITOR_THRESHOLD=200
# You can configure multiple recipients, separated by commas, e.g. test1@dify.ai,test2@dify.ai
QUEUE_MONITOR_ALERT_EMAILS=
# Monitoring interval in minutes (default: 30)
QUEUE_MONITOR_INTERVAL=30
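
A hedged sketch of how a monitor task might consume these three settings; the real task ships in the api changes of this commit and may read them differently:

```python
import os

threshold = int(os.environ.get("QUEUE_MONITOR_THRESHOLD", "200"))
alert_emails = [
    e.strip()
    for e in os.environ.get("QUEUE_MONITOR_ALERT_EMAILS", "").split(",")
    if e.strip()
]
interval_minutes = int(os.environ.get("QUEUE_MONITOR_INTERVAL", "30"))
```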

View File

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
image: langgenius/dify-api:1.4.1
image: langgenius/dify-api:1.4.2
restart: always
environment:
# Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.4.1
image: langgenius/dify-api:1.4.2
restart: always
environment:
# Use the shared environment variables.
@ -57,7 +57,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.4.1
image: langgenius/dify-web:1.4.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -142,7 +142,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.1.1-local
image: langgenius/dify-plugin-daemon:0.1.2-local
restart: always
environment:
# Use the shared environment variables.
@ -184,6 +184,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:

View File

@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-}
@ -497,11 +501,14 @@ x-shared-env: &shared-api-worker-env
OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000}
OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000}
ALLOW_EMBED: ${ALLOW_EMBED:-false}
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
services:
# API service
api:
image: langgenius/dify-api:1.4.1
image: langgenius/dify-api:1.4.2
restart: always
environment:
# Use the shared environment variables.
@ -530,7 +537,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:1.4.1
image: langgenius/dify-api:1.4.2
restart: always
environment:
# Use the shared environment variables.
@ -556,7 +563,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.4.1
image: langgenius/dify-web:1.4.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@ -641,7 +648,7 @@ services:
# plugin daemon
plugin_daemon:
image: langgenius/dify-plugin-daemon:0.1.1-local
image: langgenius/dify-plugin-daemon:0.1.2-local
restart: always
environment:
# Use the shared environment variables.
@ -683,6 +690,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:

View File

@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
# Plugin OSS: Volcengine TOS
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=

View File

@ -47,7 +47,7 @@ class DifyClient:
def text_to_audio(self, text: str, user: str, streaming: bool = False):
data = {"text": text, "user": user, "streaming": streaming}
return self._send_request("POST", "/text-to-audio", data=data)
return self._send_request("POST", "/text-to-audio", json=data)
def get_meta(self, user):
params = {"user": user}

View File

@ -18,9 +18,10 @@ const queryDateFormat = 'YYYY-MM-DD HH:mm'
export type IChartViewProps = {
appId: string
headerRight: React.ReactNode
}
export default function ChartView({ appId }: IChartViewProps) {
export default function ChartView({ appId, headerRight }: IChartViewProps) {
const { t } = useTranslation()
const appDetail = useAppStore(state => state.appDetail)
const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
@ -46,19 +47,24 @@ export default function ChartView({ appId }: IChartViewProps) {
return (
<div>
<div className='system-xl-semibold mb-4 mt-8 flex flex-row items-center text-text-primary'>
<span className='mr-3'>{t('appOverview.analysis.title')}</span>
<SimpleSelect
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
onSelect={(item) => {
const id = item.value
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
const name = item.name || t('appLog.filter.period.allTime')
onSelect({ value, name })
}}
defaultValue={'2'}
/>
<div className='mb-4'>
<div className='system-xl-semibold mb-2 text-text-primary'>{t('common.appMenus.overview')}</div>
<div className='flex items-center justify-between'>
<div className='flex flex-row items-center'>
<SimpleSelect
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
className='mt-0 !w-40'
onSelect={(item) => {
const id = item.value
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
const name = item.name || t('appLog.filter.period.allTime')
onSelect({ value, name })
}}
defaultValue={'2'}
/>
</div>
{headerRight}
</div>
</div>
{!isWorkflow && (
<div className='mb-6 grid w-full grid-cols-1 gap-6 xl:grid-cols-2'>

View File

@ -1,6 +1,5 @@
import React from 'react'
import ChartView from './chartView'
import CardView from './cardView'
import TracingPanel from './tracing/panel'
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'
@ -18,9 +17,10 @@ const Overview = async (props: IDevelopProps) => {
return (
<div className="h-full overflow-scroll bg-chatbot-bg px-4 py-6 sm:px-12">
<ApikeyInfoPanel />
<TracingPanel />
<CardView appId={appId} />
<ChartView appId={appId} />
<ChartView
appId={appId}
headerRight={<TracingPanel />}
/>
</div>
)
}

View File

@ -23,19 +23,6 @@ import Divider from '@/app/components/base/divider'
const I18N_PREFIX = 'app.tracing'
const Title = ({
className,
}: {
className?: string
}) => {
const { t } = useTranslation()
return (
<div className={cn('system-xl-semibold flex items-center text-text-primary', className)}>
{t('common.appMenus.overview')}
</div>
)
}
const Panel: FC = () => {
const { t } = useTranslation()
const pathname = usePathname()
@ -154,7 +141,6 @@ const Panel: FC = () => {
if (!isLoaded) {
return (
<div className='mb-3 flex items-center justify-between'>
<Title className='h-[41px]' />
<div className='w-[200px]'>
<Loading />
</div>
@ -163,8 +149,7 @@ const Panel: FC = () => {
}
return (
<div className={cn('mb-3 flex items-center justify-between')}>
<Title className='h-[41px]' />
<div className={cn('flex items-center justify-between')}>
<div
className={cn(
'flex cursor-pointer items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',

View File

@ -55,6 +55,7 @@ const weaveConfigTemplate = {
entity: '',
project: '',
endpoint: '',
host: '',
}
const ProviderConfigModal: FC<Props> = ({
@ -226,6 +227,13 @@ const ProviderConfigModal: FC<Props> = ({
onChange={handleConfigChange('endpoint')}
placeholder={'https://trace.wandb.ai/'}
/>
<Field
label='Host'
labelClassName='!text-sm'
value={(config as WeaveConfig).host}
onChange={handleConfigChange('host')}
placeholder={'https://api.wandb.ai'}
/>
</>
)}
{type === TracingProvider.langSmith && (

View File

@ -29,4 +29,5 @@ export type WeaveConfig = {
entity: string
project: string
endpoint: string
host: string
}

View File

@ -4,7 +4,7 @@ import { useContext, useContextSelector } from 'use-context-selector'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill } from '@remixicon/react'
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import type { App } from '@/types/app'
import Confirm from '@/app/components/base/confirm'
@ -338,7 +338,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
</div>
<div className='flex h-5 w-5 shrink-0 items-center justify-center'>
{app.access_mode === AccessMode.PUBLIC && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.anyone')}>
<RiGlobalLine className='h-4 w-4 text-text-accent' />
<RiGlobalLine className='h-4 w-4 text-text-quaternary' />
</Tooltip>}
{app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.specific')}>
<RiLockLine className='h-4 w-4 text-text-quaternary' />
@ -346,6 +346,9 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
{app.access_mode === AccessMode.ORGANIZATION && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.organization')}>
<RiBuildingLine className='h-4 w-4 text-text-quaternary' />
</Tooltip>}
{app.access_mode === AccessMode.EXTERNAL_MEMBERS && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.external')}>
<RiVerifiedBadgeLine className='h-4 w-4 text-text-quaternary' />
</Tooltip>}
</div>
</div>
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>

View File

@ -88,11 +88,11 @@ const Apps = () => {
const anchorRef = useRef<HTMLDivElement>(null)
const options = [
{ value: 'all', text: t('app.types.all'), icon: <RiApps2Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'workflow', text: t('app.types.workflow'), icon: <RiExchange2Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'advanced-chat', text: t('app.types.advanced'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'chat', text: t('app.types.chatbot'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'agent-chat', text: t('app.types.agent'), icon: <RiRobot3Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'completion', text: t('app.types.completion'), icon: <RiFile4Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'advanced-chat', text: t('app.types.advanced'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
{ value: 'workflow', text: t('app.types.workflow'), icon: <RiExchange2Line className='mr-1 h-[14px] w-[14px]' /> },
]
useEffect(() => {

View File

@ -87,7 +87,7 @@ const Container = () => {
return (
<div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'>
<div className='sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
<div className='sticky top-0 z-10 flex h-[80px] shrink-0 flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
<TabSliderNew
value={activeTab}
onChange={newActiveTab => setActiveTab(newActiveTab)}

View File

@ -192,15 +192,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- original_document_id が渡されない場合、新しい操作が実行され、process_rule が必要です。
- <code>indexing_technique</code> インデックスモード
- <code>high_quality</code> 高品質: 埋め込みモデルを使用してベクトルデータベースインデックスを構築
- <code>economy</code> 経済: キーワードテーブルインデックスの反転インデックスを構築
- <code>high_quality</code> 高品質埋め込みモデルを使用してベクトルデータベースインデックスを構築
- <code>economy</code> 経済キーワードテーブルインデックスの反転インデックスを構築
- <code>doc_form</code> インデックス化された内容の形式
- <code>text_model</code> テキストドキュメントは直接埋め込まれます; `economy` モードではこの形式がデフォルト
- <code>hierarchical_model</code> 親子モード
- <code>qa_model</code> Q&A モード: 分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
- <code>qa_model</code> Q&A モード分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
- <code>doc_language</code> Q&A モードでは、ドキュメントの言語を指定します。例: <code>English</code>, <code>Chinese</code>
- <code>doc_language</code> Q&A モードでは、ドキュメントの言語を指定します。例<code>English</code>, <code>Chinese</code>
- <code>process_rule</code> 処理ルール
- <code>mode</code> (string) クリーニング、セグメンテーションモード、自動 / カスタム
@ -214,7 +214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>segmentation</code> (object) セグメンテーションルール
- <code>separator</code> カスタムセグメント識別子。現在は 1 つの区切り文字のみ設定可能。デフォルトは \n
- <code>max_tokens</code> 最大長 (トークン) デフォルトは 1000
- <code>parent_mode</code> 親チャンクの検索モード: <code>full-doc</code> 全文検索 / <code>paragraph</code> 段落検索
- <code>parent_mode</code> 親チャンクの検索モード<code>full-doc</code> 全文検索 / <code>paragraph</code> 段落検索
- <code>subchunk_segmentation</code> (object) 子チャンクルール
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
@ -324,7 +324,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>partial_members</code> 一部のメンバー
</Property>
<Property name='provider' type='string' key='provider'>
プロバイダー (オプション、デフォルト: vendor)
プロバイダー (オプション、デフォルトvendor)
- <code>vendor</code> ベンダー
- <code>external</code> 外部ナレッジ
</Property>
@ -415,16 +415,16 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
検索キーワード、オプション
</Property>
<Property name='tag_ids' type='array[string]' key='tag_ids'>
タグIDリスト、オプション
タグ ID リスト、オプション
</Property>
<Property name='page' type='string' key='page'>
ページ番号、オプション、デフォルト1
ページ番号、オプション、デフォルト 1
</Property>
<Property name='limit' type='string' key='limit'>
返されるアイテム数、オプション、デフォルト20、範囲1-100
返されるアイテム数、オプション、デフォルト 20、範囲 1-100
</Property>
<Property name='include_all' type='boolean' key='include_all'>
すべてのデータセットを含めるかどうか所有者のみ有効、オプション、デフォルトはfalse
すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトは false
</Property>
</Properties>
</Col>
@ -2013,7 +2013,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='name' type='string'>
(text) 新しいタグ名、必須、最大長50文字
(text) 新しいタグ名、必須、最大長 50 文字
</Property>
</Properties>
</Col>
@ -2099,10 +2099,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='name' type='string'>
(text) 変更後のタグ名、必須、最大長50文字
(text) 変更後のタグ名、必須、最大長 50 文字
</Property>
<Property name='tag_id' type='string'>
(text) タグID、必須
(text) タグ ID、必須
</Property>
</Properties>
</Col>
@ -2147,7 +2147,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_id' type='string'>
(text) タグID、必須
(text) タグ ID、必須
</Property>
</Properties>
</Col>
@ -2188,10 +2188,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_ids' type='list'>
(list) タグIDリスト、必須
(list) タグ ID リスト、必須
</Property>
<Property name='target_id' type='string'>
(text) ナレッジベースID、必須
(text) ナレッジベース ID、必須
</Property>
</Properties>
</Col>
@ -2230,10 +2230,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_id' type='string'>
(text) タグID、必須
(text) タグ ID、必須
</Property>
<Property name='target_id' type='string'>
(text) ナレッジベースID、必須
(text) ナレッジベース ID、必須
</Property>
</Properties>
</Col>
@ -2273,7 +2273,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Path
<Properties>
<Property name='dataset_id' type='string'>
(text) ナレッジベースID
(text) ナレッジベース ID
</Property>
</Properties>
</Col>

View File

@ -207,7 +207,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
- <code>process_rule</code> 处理规则
- <code>mode</code> (string) 清洗、分段模式 automatic 自动 / custom 自定义 / hierarchical 父子
- <code>mode</code> (string) 清洗、分段模式automatic 自动 / custom 自定义 / hierarchical 父子
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
- <code>pre_processing_rules</code> (array[object]) 预处理规则
- <code>id</code> (string) 预处理规则的唯一标识符
@ -234,12 +234,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>hybrid_search</code> 混合检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>reranking_enable</code> (bool) 是否开启rerank
- <code>reranking_enable</code> (bool) 是否开启 rerank
- <code>reranking_model</code> (object) Rerank 模型配置
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
- <code>reranking_model_name</code> (string) Rerank 模型的名称
- <code>top_k</code> (int) 召回条数
- <code>score_threshold_enabled</code> (bool)是否开启召回分数限制
- <code>score_threshold_enabled</code> (bool) 是否开启召回分数限制
- <code>score_threshold</code> (float) 召回分数限制
</Property>
<Property name='embedding_model' type='string' key='embedding_model'>
@ -350,12 +350,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>hybrid_search</code> 混合检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>reranking_enable</code> (bool) 是否开启rerank
- <code>reranking_enable</code> (bool) 是否开启 rerank
- <code>reranking_model</code> (object) Rerank 模型配置
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
- <code>reranking_model_name</code> (string) Rerank 模型的名称
- <code>top_k</code> (int) 召回条数
- <code>score_threshold_enabled</code> (bool)是否开启召回分数限制
- <code>score_threshold_enabled</code> (bool) 是否开启召回分数限制
- <code>score_threshold</code> (float) 召回分数限制
</Property>
</Properties>
@ -1322,7 +1322,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
文档 ID
</Property>
<Property name='segment_id' type='string' key='segment_id'>
文档分段ID
文档分段 ID
</Property>
</Properties>
</Col>
@ -1435,7 +1435,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
文档 ID
</Property>
<Property name='segment_id' type='string' key='segment_id'>
文档分段ID
文档分段 ID
</Property>
</Properties>
@ -2223,7 +2223,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>document_id</code> (string) 文档 ID
- <code>metadata_list</code> (list) 元数据列表
- <code>id</code> (string) 元数据 ID
- <code>type</code> (string) 元数据类型
- <code>value</code> (string) 元数据值
- <code>name</code> (string) 元数据名称
</Property>
</Properties>
@ -2404,7 +2404,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='name' type='string'>
(text) 新标签名称必填最大长度为50
(text) 新标签名称,必填,最大长度为 50
</Property>
</Properties>
</Col>
@ -2490,10 +2490,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='name' type='string'>
(text) 修改后的标签名称必填最大长度为50
(text) 修改后的标签名称,必填,最大长度为 50
</Property>
<Property name='tag_id' type='string'>
(text) 标签ID必填
(text) 标签 ID必填
</Property>
</Properties>
</Col>
@ -2538,7 +2538,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_id' type='string'>
(text) 标签ID必填
(text) 标签 ID必填
</Property>
</Properties>
</Col>
@ -2579,10 +2579,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_ids' type='list'>
(list) 标签ID列表必填
(list) 标签 ID 列表,必填
</Property>
<Property name='target_id' type='string'>
(text) 知识库ID必填
(text) 知识库 ID必填
</Property>
</Properties>
</Col>
@ -2621,10 +2621,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
<Properties>
<Property name='tag_id' type='string'>
(text) 标签ID必填
(text) 标签 ID必填
</Property>
<Property name='target_id' type='string'>
(text) 知识库ID必填
(text) 知识库 ID必填
</Property>
</Properties>
</Col>
@ -2664,7 +2664,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Path
<Properties>
<Property name='dataset_id' type='string'>
(text) 知识库ID
(text) 知识库 ID
</Property>
</Properties>
</Col>

View File

@ -1,14 +1,48 @@
import React from 'react'
'use client'
import React, { useEffect, useState } from 'react'
import type { FC } from 'react'
import type { Metadata } from 'next'
export const metadata: Metadata = {
icons: 'data:,', // prevent browser from using default favicon
}
import { usePathname, useSearchParams } from 'next/navigation'
import Loading from '../components/base/loading'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { AccessMode } from '@/models/access-control'
import { getAppAccessModeByAppCode } from '@/service/share'
const Layout: FC<{
children: React.ReactNode
}> = ({ children }) => {
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode)
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const pathname = usePathname()
const searchParams = useSearchParams()
const redirectUrl = searchParams.get('redirect_url')
const [isLoading, setIsLoading] = useState(true)
useEffect(() => {
(async () => {
if (!systemFeatures.webapp_auth.enabled) {
setIsLoading(false)
return
}
let appCode: string | null = null
if (redirectUrl)
appCode = redirectUrl?.split('/').pop() || null
else
appCode = pathname.split('/').pop() || null
if (!appCode)
return
setIsLoading(true)
const ret = await getAppAccessModeByAppCode(appCode)
setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC)
setIsLoading(false)
})()
}, [pathname, redirectUrl, setWebAppAccessMode])
if (isLoading || isGlobalPending) {
return <div className='flex h-full w-full items-center justify-center'>
<Loading />
</div>
}
return (
<div className="h-full min-w-[300px] pb-[env(safe-area-inset-bottom)]">
{children}

View File

@ -0,0 +1,96 @@
'use client'
import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { useContext } from 'use-context-selector'
import Countdown from '@/app/components/signin/countdown'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common'
import I18NContext from '@/context/i18n'
export default function CheckCode() {
const { t } = useTranslation()
const router = useRouter()
const searchParams = useSearchParams()
const email = decodeURIComponent(searchParams.get('email') as string)
const token = decodeURIComponent(searchParams.get('token') as string)
const [code, setVerifyCode] = useState('')
const [loading, setIsLoading] = useState(false)
const { locale } = useContext(I18NContext)
const verify = async () => {
try {
if (!code.trim()) {
Toast.notify({
type: 'error',
message: t('login.checkCode.emptyCode'),
})
return
}
if (!/^\d{6}$/.test(code)) {
Toast.notify({
type: 'error',
message: t('login.checkCode.invalidCode'),
})
return
}
setIsLoading(true)
const ret = await verifyWebAppResetPasswordCode({ email, code, token })
if (ret.is_valid) {
const params = new URLSearchParams(searchParams)
params.set('token', encodeURIComponent(ret.token))
router.push(`/webapp-reset-password/set-password?${params.toString()}`)
}
}
catch (error) { console.error(error) }
finally {
setIsLoading(false)
}
}
const resendCode = async () => {
try {
const res = await sendWebAppResetPasswordCode(email, locale)
if (res.result === 'success') {
const params = new URLSearchParams(searchParams)
params.set('token', encodeURIComponent(res.data))
router.replace(`/webapp-reset-password/check-code?${params.toString()}`)
}
}
catch (error) { console.error(error) }
}
return <div className='flex flex-col gap-3'>
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge text-text-accent-light-mode-only shadow-lg'>
<RiMailSendFill className='h-6 w-6 text-2xl' />
</div>
<div className='pb-4 pt-2'>
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
<br />
{t('login.checkCode.validTime')}
</p>
</div>
<form action="">
<input type='text' className='hidden' />
<label htmlFor="code" className='system-md-semibold mb-1 text-text-secondary'>{t('login.checkCode.verificationCode')}</label>
<Input value={code} onChange={e => setVerifyCode(e.target.value)} maxLength={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} />
<Button loading={loading} disabled={loading} className='my-3 w-full' variant='primary' onClick={verify}>{t('login.checkCode.verify')}</Button>
<Countdown onResend={resendCode} />
</form>
<div className='py-2'>
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
</div>
<div onClick={() => router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
<div className='bg-background-default-dimm inline-block rounded-full p-1'>
<RiArrowLeftLine size={12} />
</div>
<span className='system-xs-regular ml-2'>{t('login.back')}</span>
</div>
</div>
}

View File

@ -0,0 +1,30 @@
'use client'
import type { ReactNode } from 'react'
import Header from '@/app/signin/_header'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
export default function SignInLayout({ children }: { children: ReactNode }) {
const { systemFeatures } = useGlobalPublicStore()
return <>
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
<div className={cn('flex w-full shrink-0 flex-col rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
<Header />
<div className={
cn(
'flex w-full grow flex-col items-center justify-center',
'px-6',
'md:px-[108px]',
)
}>
<div className='flex w-[400px] flex-col'>
{children}
</div>
</div>
{!systemFeatures.branding.enabled && <div className='system-xs-regular px-8 py-6 text-text-tertiary'>
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
</div>}
</div>
</div>
</>
}

View File

@ -0,0 +1,104 @@
'use client'
import Link from 'next/link'
import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { useContext } from 'use-context-selector'
import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
import { emailRegex } from '@/config'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
import { sendResetPasswordCode } from '@/service/common'
import I18NContext from '@/context/i18n'
import useDocumentTitle from '@/hooks/use-document-title'
export default function CheckCode() {
const { t } = useTranslation()
useDocumentTitle('')
const searchParams = useSearchParams()
const router = useRouter()
const [email, setEmail] = useState('')
const [loading, setIsLoading] = useState(false)
const { locale } = useContext(I18NContext)
const handleGetEMailVerificationCode = async () => {
try {
if (!email) {
Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
return
}
if (!emailRegex.test(email)) {
Toast.notify({
type: 'error',
message: t('login.error.emailInValid'),
})
return
}
setIsLoading(true)
const res = await sendResetPasswordCode(email, locale)
if (res.result === 'success') {
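// prime the resend countdown before navigating to the check-code step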
localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`)
const params = new URLSearchParams(searchParams)
params.set('token', encodeURIComponent(res.data))
params.set('email', encodeURIComponent(email))
router.push(`/webapp-reset-password/check-code?${params.toString()}`)
}
else if (res.code === 'account_not_found') {
Toast.notify({
type: 'error',
message: t('login.error.registrationNotAllowed'),
})
}
else {
Toast.notify({
type: 'error',
message: res.data,
})
}
}
catch (error) {
console.error(error)
}
finally {
setIsLoading(false)
}
}
return <div className='flex flex-col gap-3'>
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge shadow-lg'>
<RiLockPasswordLine className='h-6 w-6 text-2xl text-text-accent-light-mode-only' />
</div>
<div className='pb-4 pt-2'>
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.resetPassword')}</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
{t('login.resetPasswordDesc')}
</p>
</div>
<form onSubmit={e => e.preventDefault()}>
<input type='text' className='hidden' />
<div className='mb-2'>
<label htmlFor="email" className='system-md-semibold my-2 text-text-secondary'>{t('login.email')}</label>
<div className='mt-1'>
<Input id='email' type="email" disabled={loading} value={email} placeholder={t('login.emailPlaceholder') as string} onChange={e => setEmail(e.target.value)} />
</div>
<div className='mt-3'>
<Button loading={loading} disabled={loading} variant='primary' className='w-full' onClick={handleGetEMailVerificationCode}>{t('login.sendVerificationCode')}</Button>
</div>
</div>
</form>
<div className='py-2'>
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
</div>
<Link href={`/webapp-signin?${searchParams.toString()}`} className='flex h-9 items-center justify-center text-text-tertiary hover:text-text-primary'>
<div className='inline-block rounded-full bg-background-default-dimmed p-1'>
<RiArrowLeftLine size={12} />
</div>
<span className='system-xs-regular ml-2'>{t('login.backToLogin')}</span>
</Link>
</div>
}

View File

@ -0,0 +1,188 @@
'use client'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import cn from '@/utils/classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'
import { changeWebAppPasswordWithToken } from '@/service/common'
import Toast from '@/app/components/base/toast'
import Input from '@/app/components/base/input'
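// password policy: at least 8 characters, including at least one letter and one digit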
const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
const ChangePasswordForm = () => {
const { t } = useTranslation()
const router = useRouter()
const searchParams = useSearchParams()
const token = decodeURIComponent(searchParams.get('token') || '')
const [password, setPassword] = useState('')
const [confirmPassword, setConfirmPassword] = useState('')
const [showSuccess, setShowSuccess] = useState(false)
const [showPassword, setShowPassword] = useState(false)
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
const showErrorMessage = useCallback((message: string) => {
Toast.notify({
type: 'error',
message,
})
}, [])
const getSignInUrl = () => {
return `/webapp-signin?redirect_url=${encodeURIComponent(searchParams.get('redirect_url') || '')}`
}
const AUTO_REDIRECT_TIME = 5000
const [leftTime, setLeftTime] = useState<number | undefined>(undefined)
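// useCountDown stays idle while leftTime is undefined; setting it starts the auto-redirect timer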
const [countdown] = useCountDown({
leftTime,
onEnd: () => {
router.replace(getSignInUrl())
},
})
const valid = useCallback(() => {
if (!password.trim()) {
showErrorMessage(t('login.error.passwordEmpty'))
return false
}
if (!validPassword.test(password)) {
showErrorMessage(t('login.error.passwordInvalid'))
return false
}
if (password !== confirmPassword) {
showErrorMessage(t('common.account.notEqual'))
return false
}
return true
}, [password, confirmPassword, showErrorMessage, t])
const handleChangePassword = useCallback(async () => {
if (!valid())
return
try {
await changeWebAppPasswordWithToken({
url: '/forgot-password/resets',
body: {
token,
new_password: password,
password_confirm: confirmPassword,
},
})
setShowSuccess(true)
setLeftTime(AUTO_REDIRECT_TIME)
}
catch (error) {
console.error(error)
}
}, [password, token, valid, confirmPassword])
return (
<div className={
cn(
'flex w-full grow flex-col items-center justify-center',
'px-6',
'md:px-[108px]',
)
}>
{!showSuccess && (
<div className='flex flex-col md:w-[400px]'>
<div className="mx-auto w-full">
<h2 className="title-4xl-semi-bold text-text-primary">
{t('login.changePassword')}
</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
{t('login.changePasswordTip')}
</p>
</div>
<div className="mx-auto mt-6 w-full">
<div className="bg-white">
{/* Password */}
<div className='mb-5'>
<label htmlFor="password" className="system-md-semibold my-2 text-text-secondary">
{t('common.account.newPassword')}
</label>
<div className='relative mt-1'>
<Input
id="password" type={showPassword ? 'text' : 'password'}
value={password}
onChange={e => setPassword(e.target.value)}
placeholder={t('login.passwordPlaceholder') || ''}
/>
<div className="absolute inset-y-0 right-0 flex items-center">
<Button
type="button"
variant='ghost'
onClick={() => setShowPassword(!showPassword)}
>
{showPassword ? '👀' : '😝'}
</Button>
</div>
</div>
<div className='body-xs-regular mt-1 text-text-secondary'>{t('login.error.passwordInvalid')}</div>
</div>
{/* Confirm Password */}
<div className='mb-5'>
<label htmlFor="confirmPassword" className="system-md-semibold my-2 text-text-secondary">
{t('common.account.confirmPassword')}
</label>
<div className='relative mt-1'>
<Input
id="confirmPassword"
type={showConfirmPassword ? 'text' : 'password'}
value={confirmPassword}
onChange={e => setConfirmPassword(e.target.value)}
placeholder={t('login.confirmPasswordPlaceholder') || ''}
/>
<div className="absolute inset-y-0 right-0 flex items-center">
<Button
type="button"
variant='ghost'
onClick={() => setShowConfirmPassword(!showConfirmPassword)}
>
{showConfirmPassword ? '👀' : '😝'}
</Button>
</div>
</div>
</div>
<div>
<Button
variant='primary'
className='w-full'
onClick={handleChangePassword}
>
{t('login.changePasswordBtn')}
</Button>
</div>
</div>
</div>
</div>
)}
{showSuccess && (
<div className="flex flex-col md:w-[400px]">
<div className="mx-auto w-full">
<div className="mb-3 flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle font-bold shadow-lg">
<RiCheckboxCircleFill className='h-6 w-6 text-text-success' />
</div>
<h2 className="title-4xl-semi-bold text-text-primary">
{t('login.passwordChangedTip')}
</h2>
</div>
<div className="mx-auto mt-6 w-full">
<Button variant='primary' className='w-full' onClick={() => {
setLeftTime(undefined)
router.replace(getSignInUrl())
}}>{t('login.passwordChanged')} ({Math.round(countdown / 1000)}) </Button>
</div>
</div>
)}
</div>
)
}
export default ChangePasswordForm

View File

@ -0,0 +1,115 @@
'use client'
import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useCallback, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { useContext } from 'use-context-selector'
import Countdown from '@/app/components/signin/countdown'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Toast from '@/app/components/base/toast'
import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
import I18NContext from '@/context/i18n'
import { setAccessToken } from '@/app/components/share/utils'
import { fetchAccessToken } from '@/service/share'
export default function CheckCode() {
const { t } = useTranslation()
const router = useRouter()
const searchParams = useSearchParams()
const email = decodeURIComponent(searchParams.get('email') as string)
const token = decodeURIComponent(searchParams.get('token') as string)
const [code, setVerifyCode] = useState('')
const [loading, setIsLoading] = useState(false)
const { locale } = useContext(I18NContext)
const redirectUrl = searchParams.get('redirect_url')
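// the webapp code is taken to be the last path segment of redirect_url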
const getAppCodeFromRedirectUrl = useCallback(() => {
const appCode = redirectUrl?.split('/').pop()
if (!appCode)
return null
return appCode
}, [redirectUrl])
const verify = async () => {
try {
const appCode = getAppCodeFromRedirectUrl()
if (!code.trim()) {
Toast.notify({
type: 'error',
message: t('login.checkCode.emptyCode'),
})
return
}
if (!/^\d{6}$/.test(code)) {
Toast.notify({
type: 'error',
message: t('login.checkCode.invalidCode'),
})
return
}
if (!redirectUrl || !appCode) {
Toast.notify({
type: 'error',
message: t('login.error.redirectUrlMissing'),
})
return
}
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code, token })
if (ret.result === 'success') {
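// persist the webapp token, then exchange it for an app-scoped access token before redirecting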
localStorage.setItem('webapp_access_token', ret.data.access_token)
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token })
await setAccessToken(appCode, tokenResp.access_token)
router.replace(redirectUrl)
}
}
catch (error) { console.error(error) }
finally {
setIsLoading(false)
}
}
const resendCode = async () => {
try {
const ret = await sendWebAppEMailLoginCode(email, locale)
if (ret.result === 'success') {
const params = new URLSearchParams(searchParams)
params.set('token', encodeURIComponent(ret.data))
router.replace(`/webapp-signin/check-code?${params.toString()}`)
}
}
catch (error) { console.error(error) }
}
return <div className='flex w-[400px] flex-col gap-3'>
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge shadow-lg'>
<RiMailSendFill className='h-6 w-6 text-2xl text-text-accent-light-mode-only' />
</div>
<div className='pb-4 pt-2'>
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
<p className='body-md-regular mt-2 text-text-secondary'>
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
<br />
{t('login.checkCode.validTime')}
</p>
</div>
<form action="">
<label htmlFor="code" className='system-md-semibold mb-1 text-text-secondary'>{t('login.checkCode.verificationCode')}</label>
<Input value={code} onChange={e => setVerifyCode(e.target.value)} maxLength={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} />
<Button loading={loading} disabled={loading} className='my-3 w-full' variant='primary' onClick={verify}>{t('login.checkCode.verify')}</Button>
<Countdown onResend={resendCode} />
</form>
<div className='py-2'>
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
</div>
<div onClick={() => router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
<div className='bg-background-default-dimm inline-block rounded-full p-1'>
<RiArrowLeftLine size={12} />
</div>
<span className='system-xs-regular ml-2'>{t('login.back')}</span>
</div>
</div>
}

View File

@ -0,0 +1,80 @@
'use client'
import { useRouter, useSearchParams } from 'next/navigation'
import React, { useCallback, useEffect } from 'react'
import Toast from '@/app/components/base/toast'
import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { SSOProtocol } from '@/types/feature'
import Loading from '@/app/components/base/loading'
import AppUnavailable from '@/app/components/base/app-unavailable'
const ExternalMemberSSOAuth = () => {
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const searchParams = useSearchParams()
const router = useRouter()
const redirectUrl = searchParams.get('redirect_url')
const showErrorToast = useCallback((message: string) => {
Toast.notify({
type: 'error',
message,
})
}, [])
const getAppCodeFromRedirectUrl = useCallback(() => {
const appCode = redirectUrl?.split('/').pop()
if (!appCode)
return null
return appCode
}, [redirectUrl])
const handleSSOLogin = useCallback(async () => {
const appCode = getAppCodeFromRedirectUrl()
if (!appCode || !redirectUrl) {
showErrorToast('redirect url or app code is invalid.')
return
}
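// fetch the IdP sign-in URL for whichever SSO protocol the workspace has configured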
switch (systemFeatures.webapp_auth.sso_config.protocol) {
case SSOProtocol.SAML: {
const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl)
router.push(samlRes.url)
break
}
case SSOProtocol.OIDC: {
const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl)
router.push(oidcRes.url)
break
}
case SSOProtocol.OAuth2: {
const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl)
router.push(oauth2Res.url)
break
}
case '':
break
default:
showErrorToast('SSO protocol is not supported.')
}
}, [getAppCodeFromRedirectUrl, redirectUrl, router, showErrorToast, systemFeatures.webapp_auth.sso_config.protocol])
useEffect(() => {
handleSSOLogin()
}, [handleSSOLogin])
if (!systemFeatures.webapp_auth.sso_config.protocol) {
return <div className="flex h-full items-center justify-center">
<AppUnavailable code={403} unknownReason='sso protocol is invalid.' />
</div>
}
return (
<div className="flex h-full items-center justify-center">
<Loading />
</div>
)
}
export default React.memo(ExternalMemberSSOAuth)

Some files were not shown because too many files have changed in this diff