mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/webapp-verified-sso-main
This commit is contained in:
commit
eabd34b2ae
|
|
@ -8,7 +8,7 @@ inputs:
|
|||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.6.14'
|
||||
default: '~=0.7.11'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
|
|
|
|||
|
|
@ -192,12 +192,12 @@ sdks/python-client/dist
|
|||
sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/*
|
||||
!.vscode/launch.json
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
|
||||
.idea/
|
||||
.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
|
|
@ -207,3 +207,6 @@ plugins.jsonl
|
|||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# Debugging with VS Code
|
||||
|
||||
This `launch.json.template` file provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, you should copy the contents of this file into a new file named `launch.json` in the same `.vscode` directory.
|
||||
|
||||
## How to Use
|
||||
|
||||
1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory.
|
||||
2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file.
|
||||
3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D).
|
||||
4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button.
|
||||
|
||||
## Tips
|
||||
|
||||
- If you need to debug with Edge browser instead of Chrome, modify the `serverReadyAction` configuration in the "Next.js: debug full stack" section, change `"debugWithChrome"` to `"debugWithEdge"` to use Microsoft Edge for debugging.
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask API",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001",
|
||||
"--no-debugger",
|
||||
"--no-reload"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python"
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery Worker (Solo)",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
"worker",
|
||||
"-P",
|
||||
"solo",
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/api",
|
||||
"python": "${workspaceFolder}/api/.venv/bin/python"
|
||||
},
|
||||
{
|
||||
"name": "Next.js: debug full stack",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/web/node_modules/next/dist/bin/next",
|
||||
"runtimeArgs": ["--inspect"],
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"serverReadyAction": {
|
||||
"action": "debugWithChrome",
|
||||
"killOnServerStop": true,
|
||||
"pattern": "- Local:.+(https?://.+)",
|
||||
"uriFormat": "%s",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
|
|||
|
||||
# Prevent Clickjacking
|
||||
ALLOW_EMBED=false
|
||||
|
||||
# Dataset queue monitor configuration
|
||||
QUEUE_MONITOR_THRESHOLD=200
|
||||
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
|
||||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ select = [
|
|||
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
|
||||
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
]
|
||||
|
||||
ignore = [
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
|
|||
WORKDIR /app/api
|
||||
|
||||
# Install uv
|
||||
ENV UV_VERSION=0.6.14
|
||||
ENV UV_VERSION=0.7.11
|
||||
|
||||
RUN pip install --no-cache-dir uv==${UV_VERSION}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
from typing import Any, Literal, Optional
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
|
|
@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
|
|||
)
|
||||
|
||||
|
||||
class DatasetQueueMonitorConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Dataset Queue Monitor
|
||||
"""
|
||||
|
||||
QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
|
||||
description="Threshold for dataset queue monitor",
|
||||
default=200,
|
||||
)
|
||||
QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
|
||||
description="Emails for dataset queue monitor alert, separated by commas",
|
||||
default=None,
|
||||
)
|
||||
QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
|
||||
description="Interval for dataset queue monitor in minutes",
|
||||
default=30,
|
||||
)
|
||||
|
||||
|
||||
class MiddlewareConfig(
|
||||
# place the configs in alphabet order
|
||||
CeleryConfig,
|
||||
|
|
@ -303,5 +322,6 @@ class MiddlewareConfig(
|
|||
BaiduVectorDBConfig,
|
||||
OpenGaussConfig,
|
||||
TableStoreConfig,
|
||||
DatasetQueueMonitorConfig,
|
||||
):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
|||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
|
|
|
|||
|
|
@ -369,6 +369,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
)
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
|
|
|
|||
|
|
@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
|
@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
knowledge_config = KnowledgeConfig(**args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
|
||||
if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
|
||||
raise ValueError("process_rule is required.")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
dataset_process_rule=dataset_process_rule,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
|||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
Check model status and raise ValueError if not active.
|
||||
|
||||
:raises ValueError: When model status is not active, with a descriptive message
|
||||
"""
|
||||
if self.status == ModelStatus.ACTIVE:
|
||||
return
|
||||
|
||||
error_messages = {
|
||||
ModelStatus.NO_CONFIGURE: "Model is not configured",
|
||||
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
|
||||
ModelStatus.NO_PERMISSION: "No permission to use this model",
|
||||
ModelStatus.DISABLED: "Model is disabled",
|
||||
}
|
||||
|
||||
if self.status in error_messages:
|
||||
raise ValueError(error_messages[self.status])
|
||||
|
||||
|
||||
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -41,45 +41,53 @@ class Extensible:
|
|||
extensions = []
|
||||
position_map: dict[str, int] = {}
|
||||
|
||||
# get the path of the current class
|
||||
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
|
||||
current_dir_path = os.path.dirname(current_path)
|
||||
# Get the package name from the module path
|
||||
package_name = ".".join(cls.__module__.split(".")[:-1])
|
||||
|
||||
# traverse subdirectories
|
||||
for subdir_name in os.listdir(current_dir_path):
|
||||
if subdir_name.startswith("__"):
|
||||
continue
|
||||
try:
|
||||
# Get package directory path
|
||||
package_spec = importlib.util.find_spec(package_name)
|
||||
if not package_spec or not package_spec.origin:
|
||||
raise ImportError(f"Could not find package {package_name}")
|
||||
|
||||
subdir_path = os.path.join(current_dir_path, subdir_name)
|
||||
extension_name = subdir_name
|
||||
if os.path.isdir(subdir_path):
|
||||
package_dir = os.path.dirname(package_spec.origin)
|
||||
|
||||
# Traverse subdirectories
|
||||
for subdir_name in os.listdir(package_dir):
|
||||
if subdir_name.startswith("__"):
|
||||
continue
|
||||
|
||||
subdir_path = os.path.join(package_dir, subdir_name)
|
||||
if not os.path.isdir(subdir_path):
|
||||
continue
|
||||
|
||||
extension_name = subdir_name
|
||||
file_names = os.listdir(subdir_path)
|
||||
|
||||
# is builtin extension, builtin extension
|
||||
# in the front-end page and business logic, there are special treatments.
|
||||
# Check for extension module file
|
||||
if (extension_name + ".py") not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Check for builtin flag and position
|
||||
builtin = False
|
||||
# default position is 0 can not be None for sort_to_dict_by_position_map
|
||||
position = 0
|
||||
if "__builtin__" in file_names:
|
||||
builtin = True
|
||||
|
||||
builtin_file_path = os.path.join(subdir_path, "__builtin__")
|
||||
if os.path.exists(builtin_file_path):
|
||||
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
|
||||
position_map[extension_name] = position
|
||||
|
||||
if (extension_name + ".py") not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||
py_path = os.path.join(subdir_path, extension_name + ".py")
|
||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||
# Import the extension module
|
||||
module_name = f"{package_name}.{extension_name}.{extension_name}"
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
if not spec or not spec.loader:
|
||||
raise Exception(f"Failed to load module {extension_name} from {py_path}")
|
||||
raise ImportError(f"Failed to load module {module_name}")
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# Find extension class
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
|
|
@ -87,21 +95,21 @@ class Extensible:
|
|||
break
|
||||
|
||||
if not extension_class:
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
|
||||
continue
|
||||
|
||||
# Load schema if not builtin
|
||||
json_data: dict[str, Any] = {}
|
||||
if not builtin:
|
||||
if "schema.json" not in file_names:
|
||||
json_path = os.path.join(subdir_path, "schema.json")
|
||||
if not os.path.exists(json_path):
|
||||
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
json_path = os.path.join(subdir_path, "schema.json")
|
||||
json_data = {}
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
# Create extension
|
||||
extensions.append(
|
||||
ModuleExtension(
|
||||
extension_class=extension_class,
|
||||
|
|
@ -113,6 +121,11 @@ class Extensible:
|
|||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error scanning extensions")
|
||||
raise
|
||||
|
||||
# Sort extensions by position
|
||||
sorted_extensions = sort_to_dict_by_position_map(
|
||||
position_map=position_map, data=extensions, name_func=lambda x: x.name
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
import random
|
||||
import secrets
|
||||
from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
|
|
@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
|||
if len(text_chunks) == 0:
|
||||
return True
|
||||
|
||||
text_chunk = random.choice(text_chunks)
|
||||
text_chunk = secrets.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_provider_factory = ModelProviderFactory(tenant_id)
|
||||
|
|
|
|||
|
|
@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
|
|||
deprecated: bool = False
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def support_structure_output(self) -> bool:
|
||||
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
|
||||
|
||||
|
||||
class ParameterRule(BaseModel):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig):
|
|||
entity: str | None = None
|
||||
project: str
|
||||
endpoint: str = "https://trace.wandb.ai"
|
||||
host: str | None = None
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
|
|
@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig):
|
|||
|
||||
return v
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def validate_host(cls, v, info: ValidationInfo):
|
||||
if v is not None and v != "":
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
return v
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
|
|||
return {
|
||||
"config_class": WeaveConfig,
|
||||
"secret_keys": ["api_key"],
|
||||
"other_keys": ["project", "entity", "endpoint"],
|
||||
"other_keys": ["project", "entity", "endpoint", "host"],
|
||||
"trace_instance": WeaveDataTrace,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
self.weave_api_key = weave_config.api_key
|
||||
self.project_name = weave_config.project
|
||||
self.entity = weave_config.entity
|
||||
self.host = weave_config.host
|
||||
|
||||
# Login with API key first, including host if provided
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
# Login with API key first
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if not login_status:
|
||||
logger.error("Failed to login to Weights & Biases with the provided API key")
|
||||
raise ValueError("Weave login failed")
|
||||
|
|
@ -386,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance):
|
|||
|
||||
def api_check(self):
|
||||
try:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
if self.host:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
|
||||
else:
|
||||
login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
|
||||
|
||||
if not login_status:
|
||||
raise ValueError("Weave login failed")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ from collections import defaultdict
|
|||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
|
|
@ -393,19 +395,13 @@ class ProviderManager:
|
|||
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
"""
|
||||
Get all provider records of the workspace.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
|
||||
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
for provider in providers:
|
||||
# TODO: Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
# Use provider name with prefix after the data migration
|
||||
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -416,17 +412,12 @@ class ProviderManager:
|
|||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider model records of the workspace
|
||||
provider_models = (
|
||||
db.session.query(ProviderModel)
|
||||
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
|
||||
return provider_name_to_provider_model_records_dict
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -437,17 +428,14 @@ class ProviderManager:
|
|||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
preferred_provider_types = (
|
||||
db.session.query(TenantPreferredModelProvider)
|
||||
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
preferred_provider_type.provider_name: preferred_provider_type
|
||||
for preferred_provider_type in preferred_provider_types
|
||||
}
|
||||
return provider_name_to_preferred_provider_type_records_dict
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -458,18 +446,14 @@ class ProviderManager:
|
|||
:param tenant_id: workspace id
|
||||
:return:
|
||||
"""
|
||||
provider_model_settings = (
|
||||
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
|
||||
provider_model_setting
|
||||
)
|
||||
)
|
||||
|
||||
return provider_name_to_provider_model_settings_dict
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -492,15 +476,14 @@ class ProviderManager:
|
|||
if not model_load_balancing_enabled:
|
||||
return {}
|
||||
|
||||
provider_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
provider_name_to_provider_load_balancing_model_configs_dict[
|
||||
provider_load_balancing_config.provider_name
|
||||
].append(provider_load_balancing_config)
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
|
|
@ -626,10 +609,9 @@ class ProviderManager:
|
|||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if (
|
||||
custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")
|
||||
):
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
|
|
@ -733,7 +715,7 @@ class ProviderManager:
|
|||
return SystemConfiguration(enabled=False)
|
||||
|
||||
# Convert provider_records to dict
|
||||
quota_type_to_provider_records_dict = {}
|
||||
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
|
@ -758,6 +740,11 @@ class ProviderManager:
|
|||
else:
|
||||
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
|
||||
|
||||
if provider_record.quota_used is None:
|
||||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
|
|
@ -791,10 +778,9 @@ class ProviderManager:
|
|||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
provider_credentials: dict[str, Any] = {}
|
||||
if provider_records and provider_records[0].encrypted_config:
|
||||
provider_credentials = json.loads(provider_records[0].encrypted_config)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
|
|
|
|||
|
|
@ -720,7 +720,7 @@ STOPWORDS = {
|
|||
"〉",
|
||||
"〈",
|
||||
"…",
|
||||
" ",
|
||||
" ",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
|
|
@ -731,16 +731,6 @@ STOPWORDS = {
|
|||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
"二",
|
||||
"三",
|
||||
"四",
|
||||
|
|
|
|||
|
|
@ -184,7 +184,16 @@ class OpenSearchVector(BaseVector):
|
|||
}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||
query["query"] = {
|
||||
"script_score": {
|
||||
"query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}},
|
||||
"script": {
|
||||
"source": "knn_score",
|
||||
"lang": "knn",
|
||||
"params": {"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||
|
|
@ -209,10 +218,10 @@ class OpenSearchVector(BaseVector):
|
|||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
if document_ids_filter:
|
||||
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||
full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}]
|
||||
|
||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||
|
||||
|
|
@ -255,7 +264,8 @@ class OpenSearchVector(BaseVector):
|
|||
Field.METADATA_KEY.value: {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
|
||||
"doc_id": {"type": "keyword"}, # Map doc_id to keyword type
|
||||
"document_id": {"type": "keyword"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -261,7 +261,7 @@ class OracleVector(BaseVector):
|
|||
words = pseg.cut(query)
|
||||
current_entity = ""
|
||||
for word, pos in words:
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
|
||||
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
|
||||
current_entity += word
|
||||
else:
|
||||
if current_entity:
|
||||
|
|
@ -303,7 +303,6 @@ class OracleVector(BaseVector):
|
|||
return docs
|
||||
else:
|
||||
return [Document(page_content="", metadata={})]
|
||||
return []
|
||||
|
||||
def delete(self) -> None:
|
||||
with self._get_connection() as conn:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
- audio
|
||||
- code
|
||||
- time
|
||||
- qrcode
|
||||
- webscraper
|
||||
|
|
|
|||
|
|
@ -153,8 +153,6 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
raise RuntimeError("not segments found")
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
|
|
|
|||
|
|
@ -397,19 +397,44 @@ def _extract_text_from_csv(file_content: bytes) -> str:
|
|||
if not rows:
|
||||
return ""
|
||||
|
||||
# Create Markdown table
|
||||
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
|
||||
for row in rows[1:]:
|
||||
markdown_table += "| " + " | ".join(row) + " |\n"
|
||||
# Combine multi-line text in the header row
|
||||
header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]]
|
||||
|
||||
return markdown_table.strip()
|
||||
# Create Markdown table
|
||||
markdown_table = "| " + " | ".join(header_row) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n"
|
||||
|
||||
# Process each data row and combine multi-line text in each cell
|
||||
for row in rows[1:]:
|
||||
processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row]
|
||||
markdown_table += "| " + " | ".join(processed_row) + " |\n"
|
||||
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
"""Extract text from an Excel file using pandas."""
|
||||
|
||||
def _construct_markdown_table(df: pd.DataFrame) -> str:
|
||||
"""Manually construct a Markdown table from a DataFrame."""
|
||||
# Construct the header row
|
||||
header_row = "| " + " | ".join(df.columns) + " |"
|
||||
|
||||
# Construct the separator row
|
||||
separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |"
|
||||
|
||||
# Construct the data rows
|
||||
data_rows = []
|
||||
for _, row in df.iterrows():
|
||||
data_row = "| " + " | ".join(map(str, row)) + " |"
|
||||
data_rows.append(data_row)
|
||||
|
||||
# Combine all rows into a single string
|
||||
markdown_table = "\n".join([header_row, separator_row] + data_rows)
|
||||
return markdown_table
|
||||
|
||||
try:
|
||||
excel_file = pd.ExcelFile(io.BytesIO(file_content))
|
||||
markdown_table = ""
|
||||
|
|
@ -417,8 +442,15 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
|||
try:
|
||||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
# Create Markdown table two times to separate tables with a newline
|
||||
markdown_table += df.to_markdown(index=False, floatfmt="") + "\n\n"
|
||||
|
||||
# Combine multi-line text in each cell into a single line
|
||||
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
|
||||
|
||||
# Combine multi-line text in column names into a single line
|
||||
df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns])
|
||||
|
||||
# Manually construct the Markdown table
|
||||
markdown_table += _construct_markdown_table(df) + "\n\n"
|
||||
except Exception as e:
|
||||
continue
|
||||
return markdown_table
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import base64
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from random import randint
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
|
|
@ -434,4 +435,4 @@ def _generate_random_string(n: int) -> str:
|
|||
>>> _generate_random_string(5)
|
||||
'abcde'
|
||||
"""
|
||||
return "".join([chr(randint(97, 122)) for _ in range(n)])
|
||||
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n))
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
|
|||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: dict | None = None
|
||||
structured_output_enabled: bool = False
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
|
|||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
|
|||
from core.workflow.utils.structured_output.entities import (
|
||||
ResponseFormat,
|
||||
SpecialModelType,
|
||||
SupportStructuredOutputStatus,
|
||||
)
|
||||
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
|
@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
except LLMNodeError as e:
|
||||
except ValueError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
def _fetch_model_config(
|
||||
self, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model_name = node_data_model.name
|
||||
provider_name = node_data_model.provider
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_name, model_type=ModelType.LLM
|
||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
# model config
|
||||
completion_params = node_data_model.completion_params
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
|
||||
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
|
||||
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(completion_params, model_schema.parameter_rules)
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
if self.node_data.structured_output_enabled:
|
||||
if model_schema.support_structure_output:
|
||||
node_data_model.completion_params = self._handle_native_json_schema(
|
||||
node_data_model.completion_params, model_schema.parameter_rules
|
||||
)
|
||||
else:
|
||||
# Set appropriate response format based on model capabilities
|
||||
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
|
||||
|
||||
return model, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=model_mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=model.provider_model_bundle,
|
||||
credentials=model.credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
|
|
@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
support_structured_output = self._check_model_structured_output_support()
|
||||
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
stop = model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=self.node_data.model.provider,
|
||||
model=self.node_data.model.name,
|
||||
)
|
||||
model_schema = model.model_type_instance.get_model_schema(
|
||||
model=self.node_data.model.name,
|
||||
credentials=model.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
|
||||
if self.node_data.structured_output_enabled:
|
||||
if not model_schema.support_structure_output:
|
||||
filtered_prompt_messages = self._handle_prompt_based_schema(
|
||||
prompt_messages=filtered_prompt_messages,
|
||||
)
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
|
||||
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
|
||||
structured_output: dict[str, Any] = {}
|
||||
|
|
@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
|
||||
|
|
@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
|
||||
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
|
||||
"""
|
||||
Check if the current model supports structured output.
|
||||
|
||||
Returns:
|
||||
SupportStructuredOutput: The support status of structured output
|
||||
"""
|
||||
# Early return if structured output is disabled
|
||||
if (
|
||||
not isinstance(self.node_data, LLMNodeData)
|
||||
or not self.node_data.structured_output_enabled
|
||||
or not self.node_data.structured_output
|
||||
):
|
||||
return SupportStructuredOutputStatus.DISABLED
|
||||
# Get model schema and check if it exists
|
||||
model_schema = self._fetch_model_schema(self.node_data.model.provider)
|
||||
if not model_schema:
|
||||
return SupportStructuredOutputStatus.DISABLED
|
||||
|
||||
# Check if model supports structured output feature
|
||||
return (
|
||||
SupportStructuredOutputStatus.SUPPORTED
|
||||
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
|
||||
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||
)
|
||||
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
self,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
|
|
|
|||
|
|
@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
|
|||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class SupportStructuredOutputStatus(StrEnum):
|
||||
"""Constants for structured output support status"""
|
||||
|
||||
SUPPORTED = "supported"
|
||||
UNSUPPORTED = "unsupported"
|
||||
DISABLED = "disabled"
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"schedule.update_tidb_serverless_status_task",
|
||||
"schedule.clean_messages",
|
||||
"schedule.mail_clean_document_notify_task",
|
||||
"schedule.queue_monitor_task",
|
||||
]
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
beat_schedule = {
|
||||
|
|
@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
|
|||
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
|
||||
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
|
||||
},
|
||||
"datasets-queue-monitor": {
|
||||
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
||||
"schedule": timedelta(
|
||||
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
|
||||
),
|
||||
},
|
||||
}
|
||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import subprocess
|
||||
import time
|
||||
|
|
@ -18,6 +18,7 @@ from flask_restful import fields
|
|||
from configs import dify_config
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.file import helpers as file_helpers
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -175,7 +176,7 @@ def generate_string(n):
|
|||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
for i in range(n):
|
||||
result += random.choice(letters_digits)
|
||||
result += secrets.choice(letters_digits)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -196,7 +197,7 @@ def generate_text_hash(text: str) -> str:
|
|||
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype="application/json")
|
||||
return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
|
||||
else:
|
||||
|
||||
def generate() -> Generator:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
"""`workflow_draft_varaibles` add `node_execution_id` column, add an index for `workflow_node_executions`.
|
||||
|
||||
Revision ID: 4474872b0ee6
|
||||
Revises: 2adcbe1f5dfb
|
||||
Create Date: 2025-06-06 14:24:44.213018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4474872b0ee6'
|
||||
down_revision = '2adcbe1f5dfb'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# `CREATE INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
|
||||
# context manager to wrap the index creation statement.
|
||||
# Reference:
|
||||
#
|
||||
# - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
|
||||
# - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
|
||||
with op.get_context().autocommit_block():
|
||||
op.create_index(
|
||||
op.f('workflow_node_executions_tenant_id_idx'),
|
||||
"workflow_node_executions",
|
||||
['tenant_id', 'workflow_id', 'node_id', sa.literal_column('created_at DESC')],
|
||||
unique=False,
|
||||
postgresql_concurrently=True,
|
||||
)
|
||||
|
||||
with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('node_execution_id', models.types.StringUUID(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# `DROP INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block`
|
||||
# context manager to wrap the index creation statement.
|
||||
# Reference:
|
||||
#
|
||||
# - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
|
||||
# - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block
|
||||
# `DROP INDEX CONCURRENTLY` cannot run within a transaction, so commit existing transactions first.
|
||||
# Reference:
|
||||
#
|
||||
# https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot.
|
||||
with op.get_context().autocommit_block():
|
||||
op.drop_index(op.f('workflow_node_executions_tenant_id_idx'), postgresql_concurrently=True)
|
||||
|
||||
with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op:
|
||||
batch_op.drop_column('node_execution_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
|
|
@ -51,20 +54,24 @@ class Provider(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
last_used = db.Column(db.DateTime, nullable=True)
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
db.String(40), nullable=False, server_default=text("'custom'::character varying")
|
||||
)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
|
||||
|
||||
quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
|
||||
quota_limit = db.Column(db.BigInteger, nullable=True)
|
||||
quota_used = db.Column(db.BigInteger, default=0)
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
db.String(40), nullable=True, server_default=text("''::character varying")
|
||||
)
|
||||
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
|
||||
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -104,15 +111,15 @@ class ProviderModel(Base):
|
|||
),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantDefaultModel(Base):
|
||||
|
|
@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
|
|||
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(Base):
|
||||
|
|
@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
|
|||
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
preferred_provider_type = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderOrder(Base):
|
||||
|
|
@ -153,22 +160,24 @@ class ProviderOrder(Base):
|
|||
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
payment_product_id = db.Column(db.String(191), nullable=False)
|
||||
payment_id = db.Column(db.String(191))
|
||||
transaction_id = db.Column(db.String(191))
|
||||
quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
|
||||
currency = db.Column(db.String(40))
|
||||
total_amount = db.Column(db.Integer)
|
||||
payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
|
||||
paid_at = db.Column(db.DateTime)
|
||||
pay_failed_at = db.Column(db.DateTime)
|
||||
refunded_at = db.Column(db.DateTime)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
|
||||
payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
|
||||
transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
|
||||
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
|
||||
currency: Mapped[Optional[str]] = mapped_column(db.String(40))
|
||||
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
|
||||
payment_status: Mapped[str] = mapped_column(
|
||||
db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
|
||||
)
|
||||
paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelSetting(Base):
|
||||
|
|
@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
|
|||
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(Base):
|
||||
|
|
@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
|
|||
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
|
||||
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ if TYPE_CHECKING:
|
|||
from models.model import AppMode
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
from core.helper import encrypter
|
||||
|
|
@ -590,28 +590,48 @@ class WorkflowNodeExecutionModel(Base):
|
|||
"""
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
db.Index(
|
||||
"workflow_node_execution_workflow_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"workflow_run_id",
|
||||
),
|
||||
db.Index(
|
||||
"workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id"
|
||||
),
|
||||
db.Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
def __table_args__(cls): # noqa
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(), # type: ignore
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
|
@ -885,14 +905,29 @@ class WorkflowDraftVariable(Base):
|
|||
|
||||
selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")
|
||||
|
||||
# The data type of this variable's value
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
|
||||
# JSON string
|
||||
|
||||
# The variable's value serialized as a JSON string
|
||||
value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")
|
||||
|
||||
# visible
|
||||
# Controls whether the variable should be displayed in the variable inspection panel
|
||||
visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
|
||||
# Determines whether this variable can be modified by users
|
||||
editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
|
||||
|
||||
# The `node_execution_id` field identifies the workflow node execution that created this variable.
|
||||
# It corresponds to the `id` field in the `WorkflowNodeExecutionModel` model.
|
||||
#
|
||||
# This field is not `None` for system variables and node variables, and is `None`
|
||||
# for conversation variables.
|
||||
node_execution_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID,
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
|
||||
def get_selector(self) -> list[str]:
|
||||
selector = json.loads(self.selector)
|
||||
if not isinstance(selector, list):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
from flask import render_template
|
||||
from redis import Redis
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_mail import mail
|
||||
|
||||
# Create a dedicated Redis connection (using the same configuration as Celery)
|
||||
celery_broker_url = dify_config.CELERY_BROKER_URL
|
||||
|
||||
parsed = urlparse(celery_broker_url)
|
||||
host = parsed.hostname or "localhost"
|
||||
port = parsed.port or 6379
|
||||
password = parsed.password or None
|
||||
redis_db = parsed.path.strip("/") or "1" # type: ignore
|
||||
|
||||
celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
|
||||
|
||||
|
||||
@app.celery.task(queue="monitor")
|
||||
def queue_monitor_task():
|
||||
queue_name = "dataset"
|
||||
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
|
||||
|
||||
try:
|
||||
queue_length = celery_redis.llen(f"{queue_name}")
|
||||
logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
|
||||
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
|
||||
|
||||
if queue_length >= threshold:
|
||||
warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
|
||||
logging.warning(click.style(warning_msg, fg="red"))
|
||||
alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
|
||||
if alter_emails:
|
||||
to_list = alter_emails.split(",")
|
||||
for to in to_list:
|
||||
try:
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
html_content = render_template(
|
||||
"queue_monitor_alert_email_template_en-US.html",
|
||||
queue_name=queue_name,
|
||||
queue_length=queue_length,
|
||||
threshold=threshold,
|
||||
alert_time=current_time,
|
||||
)
|
||||
mail.send(
|
||||
to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(click.style("Exception occurred during sending email", fg="red"))
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(click.style("Exception occurred during queue monitoring", fg="red"))
|
||||
finally:
|
||||
if db.session.is_active:
|
||||
db.session.close()
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import base64
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
|
@ -261,7 +260,7 @@ class AccountService:
|
|||
|
||||
@staticmethod
|
||||
def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]:
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
account=account, token_type="account_deletion", additional_data={"code": code}
|
||||
)
|
||||
|
|
@ -429,7 +428,7 @@ class AccountService:
|
|||
additional_data: dict[str, Any] = {},
|
||||
):
|
||||
if not code:
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
additional_data["code"] = code
|
||||
token = TokenManager.generate_token(
|
||||
account=account, email=email, token_type="reset_password", additional_data=additional_data
|
||||
|
|
@ -456,7 +455,7 @@ class AccountService:
|
|||
|
||||
raise EmailCodeLoginRateLimitExceededError()
|
||||
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import copy
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from collections import Counter
|
||||
|
|
@ -970,7 +970,7 @@ class DocumentService:
|
|||
documents.append(document)
|
||||
batch = document.batch
|
||||
else:
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + secrets.randbelow(exclusive_upper_bound=900000))
|
||||
# save process rule
|
||||
if not dataset_process_rule:
|
||||
process_rule = knowledge_config.process_rule
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ class TagService:
|
|||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
|
|
@ -88,7 +90,7 @@ class TagService:
|
|||
|
||||
@staticmethod
|
||||
def update_tags(args: dict, tag_id: str) -> Tag:
|
||||
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
|
||||
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
|
||||
raise ValueError("Tag name already exists")
|
||||
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import enum
|
||||
import random
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ class WebAppAuthService:
|
|||
if email is None:
|
||||
raise ValueError("Email must be provided.")
|
||||
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import uuid
|
|||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
|
|
@ -68,11 +68,6 @@ def batch_create_segment_to_index_task(
|
|||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
word_count_change = 0
|
||||
segments_to_insert: list[str] = []
|
||||
max_position_stmt = select(func.max(DocumentSegment.position)).where(
|
||||
DocumentSegment.document_id == dataset_document.id
|
||||
)
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: 'Arial', sans-serif;
|
||||
line-height: 16pt;
|
||||
color: #101828;
|
||||
background-color: #e9ebf0;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.container {
|
||||
width: 600px;
|
||||
min-height: 605px;
|
||||
margin: 40px auto;
|
||||
padding: 36px 48px;
|
||||
background-color: #fcfcfd;
|
||||
border-radius: 16px;
|
||||
border: 1px solid #ffffff;
|
||||
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
|
||||
}
|
||||
|
||||
.header {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.header img {
|
||||
max-width: 100px;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-weight: 600;
|
||||
font-size: 24px;
|
||||
line-height: 28.8px;
|
||||
}
|
||||
|
||||
.description {
|
||||
font-size: 13px;
|
||||
line-height: 16px;
|
||||
color: #676f83;
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
.alert-content {
|
||||
padding: 16px 32px;
|
||||
text-align: center;
|
||||
border-radius: 16px;
|
||||
background-color: #fef0f0;
|
||||
margin: 16px auto;
|
||||
border: 1px solid #fda29b;
|
||||
}
|
||||
|
||||
.alert-title {
|
||||
line-height: 24px;
|
||||
font-weight: 700;
|
||||
font-size: 18px;
|
||||
color: #d92d20;
|
||||
}
|
||||
|
||||
.alert-detail {
|
||||
line-height: 20px;
|
||||
font-size: 14px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.typography {
|
||||
letter-spacing: -0.07px;
|
||||
font-weight: 400;
|
||||
font-style: normal;
|
||||
font-size: 14px;
|
||||
line-height: 20px;
|
||||
color: #354052;
|
||||
margin-top: 12px;
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
.typography p{
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.typography-title {
|
||||
color: #101828;
|
||||
font-size: 14px;
|
||||
font-style: normal;
|
||||
font-weight: 600;
|
||||
line-height: 20px;
|
||||
margin-top: 12px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.tip-list{
|
||||
margin: 0;
|
||||
padding-left: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<img src="https://assets.dify.ai/images/logo.png" alt="Dify Logo" />
|
||||
</div>
|
||||
<p class="title">Queue Monitoring Alert</p>
|
||||
<p class="typography">Our system has detected an abnormal queue status that requires your attention:</p>
|
||||
|
||||
<div class="alert-content">
|
||||
<div class="alert-title">Queue Task Alert</div>
|
||||
<div class="alert-detail">
|
||||
Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="typography">
|
||||
<p style="margin-bottom:4px">Recommended actions:</p>
|
||||
<p>1. Check the queue processing status in the system dashboard</p>
|
||||
<p>2. Verify if there are any processing bottlenecks</p>
|
||||
<p>3. Consider scaling up workers if needed</p>
|
||||
</div>
|
||||
|
||||
<p class="typography-title">Additional Information:</p>
|
||||
<ul class="typography tip-list">
|
||||
<li>Alert triggered at: {{alert_time}}</li>
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
|
|
@ -3,11 +3,16 @@ import os
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app_factory import create_app
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
|
@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
|
|||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def app():
|
||||
# Set up storage configuration
|
||||
os.environ["STORAGE_TYPE"] = "opendal"
|
||||
os.environ["OPENDAL_SCHEME"] = "fs"
|
||||
os.environ["OPENDAL_FS_ROOT"] = "storage"
|
||||
|
||||
# Ensure storage directory exists
|
||||
os.makedirs("storage", exist_ok=True)
|
||||
|
||||
app = create_app()
|
||||
dify_config.LOGIN_DISABLED = True
|
||||
return app
|
||||
|
||||
|
||||
def init_llm_node(config: dict) -> LLMNode:
|
||||
graph_config = {
|
||||
"edges": [
|
||||
|
|
@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
|
|||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
# Use proper UUIDs for database compatibility
|
||||
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
|
||||
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
|
||||
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
|
|
@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
|
|||
return node
|
||||
|
||||
|
||||
def test_execute_llm(setup_model_mock):
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {
|
||||
"provider": "langgenius/openai/openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
def test_execute_llm(app):
|
||||
with app.app_context():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {
|
||||
"provider": "langgenius/openai/openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
|
||||
},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
|
||||
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials=credentials,
|
||||
)
|
||||
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "langgenius/openai/openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert item.run_result.outputs is not None
|
||||
assert item.run_result.outputs.get("text") is not None
|
||||
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
|
||||
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
|
||||
"""
|
||||
Test execute LLM node with jinja2
|
||||
"""
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
||||
{"variable": "output", "value_selector": ["abc", "output"]},
|
||||
]
|
||||
with app.app_context():
|
||||
node = init_llm_node(
|
||||
config={
|
||||
"id": "llm",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "llm",
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
||||
{"variable": "output", "value_selector": ["abc", "output"]},
|
||||
]
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
||||
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
||||
"edition_type": "jinja2",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"text": "{{#sys.query#}}",
|
||||
"jinja2_text": "{{sys_query}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
||||
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
||||
"edition_type": "jinja2",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"text": "{{#sys.query#}}",
|
||||
"jinja2_text": "{{sys_query}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
],
|
||||
"memory": None,
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
# Create a proper LLM result with real entities
|
||||
mock_usage = LLMUsage(
|
||||
prompt_tokens=30,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("1000"),
|
||||
prompt_price=Decimal("0.00003"),
|
||||
completion_tokens=20,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("1000"),
|
||||
completion_price=Decimal("0.00004"),
|
||||
total_tokens=50,
|
||||
total_price=Decimal("0.00007"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
node._fetch_model_config = get_mocked_fetch_model_config(
|
||||
provider="langgenius/openai/openai",
|
||||
model="gpt-3.5-turbo",
|
||||
mode="chat",
|
||||
credentials=credentials,
|
||||
)
|
||||
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
mock_llm_result = LLMResult(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt_messages=[],
|
||||
message=mock_message,
|
||||
usage=mock_usage,
|
||||
)
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
|
||||
# Create a simple mock model instance that doesn't call real providers
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_llm.return_value = mock_llm_result
|
||||
|
||||
# Create a simple mock model config with required attributes
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.mode = "chat"
|
||||
mock_model_config.provider = "openai"
|
||||
mock_model_config.model = "gpt-3.5-turbo"
|
||||
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config_func(_node_data_model):
|
||||
return mock_model_instance, mock_model_config
|
||||
|
||||
# Also mock ModelManager.get_model_instance to avoid database calls
|
||||
def mock_get_model_instance(_self, **kwargs):
|
||||
return mock_model_instance
|
||||
|
||||
with (
|
||||
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
|
||||
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
|
||||
):
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
for item in result:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.process_data is not None
|
||||
assert "sunny" in json.dumps(item.run_result.process_data)
|
||||
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
|
||||
|
||||
|
||||
def test_extract_json():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import random
|
||||
import secrets
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -34,7 +34,7 @@ def test_retry_logic_success(mock_request):
|
|||
side_effects = []
|
||||
|
||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
||||
status_code = random.choice(STATUS_FORCELIST)
|
||||
status_code = secrets.choice(STATUS_FORCELIST)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
side_effects.append(mock_response)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import io
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
|
||||
|
|
@ -187,145 +189,134 @@ def test_node_type(document_extractor_node):
|
|||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_single_sheet(mock_excel_file):
|
||||
"""Test extracting text from Excel file with single sheet."""
|
||||
# Mock DataFrame
|
||||
mock_df = Mock()
|
||||
mock_df.dropna = Mock()
|
||||
mock_df.to_markdown.return_value = "| Name | Age |\n|------|-----|\n| John | 25 |"
|
||||
"""Test extracting text from Excel file with single sheet and multiline content."""
|
||||
|
||||
# Test multi-line cell
|
||||
data = {"Name\nwith\nnewline": ["John\nDoe", "Jane\nSmith"], "Age": [25, 30]}
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["Sheet1"]
|
||||
mock_excel_instance.parse.return_value = mock_df
|
||||
mock_excel_instance.parse.return_value = df
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"fake_excel_content"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
expected_manual = "| Name with newline | Age |\n| ----------------- | --- |\n\
|
||||
| John Doe | 25 |\n| Jane Smith | 30 |\n\n"
|
||||
|
||||
expected = "| Name | Age |\n|------|-----|\n| John | 25 |\n\n"
|
||||
assert result == expected
|
||||
mock_excel_file.assert_called_once()
|
||||
mock_df.dropna.assert_called_once_with(how="all", inplace=True)
|
||||
mock_df.to_markdown.assert_called_once_with(index=False, floatfmt="")
|
||||
assert expected_manual == result
|
||||
mock_excel_instance.parse.assert_called_once_with(sheet_name="Sheet1")
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_multiple_sheets(mock_excel_file):
|
||||
"""Test extracting text from Excel file with multiple sheets."""
|
||||
# Mock DataFrames for different sheets
|
||||
mock_df1 = Mock()
|
||||
mock_df1.dropna = Mock()
|
||||
mock_df1.to_markdown.return_value = "| Product | Price |\n|---------|-------|\n| Apple | 1.50 |"
|
||||
"""Test extracting text from Excel file with multiple sheets and multiline content."""
|
||||
|
||||
mock_df2 = Mock()
|
||||
mock_df2.dropna = Mock()
|
||||
mock_df2.to_markdown.return_value = "| City | Population |\n|------|------------|\n| NYC | 8000000 |"
|
||||
# Test multi-line cell
|
||||
data1 = {"Product\nName": ["Apple\nRed", "Banana\nYellow"], "Price": [1.50, 0.99]}
|
||||
df1 = pd.DataFrame(data1)
|
||||
|
||||
data2 = {"City\nName": ["New\nYork", "Los\nAngeles"], "Population": [8000000, 3900000]}
|
||||
df2 = pd.DataFrame(data2)
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["Products", "Cities"]
|
||||
mock_excel_instance.parse.side_effect = [mock_df1, mock_df2]
|
||||
mock_excel_instance.parse.side_effect = [df1, df2]
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"fake_excel_content_multiple_sheets"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
expected = (
|
||||
"| Product | Price |\n|---------|-------|\n| Apple | 1.50 |\n\n"
|
||||
"| City | Population |\n|------|------------|\n| NYC | 8000000 |\n\n"
|
||||
)
|
||||
assert result == expected
|
||||
expected_manual1 = "| Product Name | Price |\n| ------------ | ----- |\n\
|
||||
| Apple Red | 1.5 |\n| Banana Yellow | 0.99 |\n\n"
|
||||
expected_manual2 = "| City Name | Population |\n| --------- | ---------- |\n\
|
||||
| New York | 8000000 |\n| Los Angeles | 3900000 |\n\n"
|
||||
|
||||
assert expected_manual1 in result
|
||||
assert expected_manual2 in result
|
||||
|
||||
assert mock_excel_instance.parse.call_count == 2
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_empty_sheets(mock_excel_file):
|
||||
"""Test extracting text from Excel file with empty sheets."""
|
||||
# Mock empty DataFrame
|
||||
mock_df = Mock()
|
||||
mock_df.dropna = Mock()
|
||||
mock_df.to_markdown.return_value = ""
|
||||
|
||||
# Empty excel
|
||||
df = pd.DataFrame()
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["EmptySheet"]
|
||||
mock_excel_instance.parse.return_value = mock_df
|
||||
mock_excel_instance.parse.return_value = df
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"fake_excel_empty_content"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
expected = "\n\n"
|
||||
expected = "| |\n| |\n\n"
|
||||
assert result == expected
|
||||
|
||||
mock_excel_instance.parse.assert_called_once_with(sheet_name="EmptySheet")
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
|
||||
"""Test handling of sheet parsing errors - should continue with other sheets."""
|
||||
# Mock DataFrames - one successful, one that raises exception
|
||||
mock_df_success = Mock()
|
||||
mock_df_success.dropna = Mock()
|
||||
mock_df_success.to_markdown.return_value = "| Data | Value |\n|------|-------|\n| Test | 123 |"
|
||||
|
||||
# Test error
|
||||
data = {"Data": ["Test"], "Value": [123]}
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
|
||||
mock_excel_instance.parse.side_effect = [mock_df_success, Exception("Parse error")]
|
||||
mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"fake_excel_mixed_content"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
expected = "| Data | Value |\n|------|-------|\n| Test | 123 |\n\n"
|
||||
assert result == expected
|
||||
expected_manual = "| Data | Value |\n| ---- | ----- |\n| Test | 123 |\n\n"
|
||||
|
||||
assert expected_manual == result
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_file_error(mock_excel_file):
|
||||
"""Test handling of Excel file reading errors."""
|
||||
mock_excel_file.side_effect = Exception("Invalid Excel file")
|
||||
|
||||
file_content = b"invalid_excel_content"
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
_extract_text_from_excel(file_content)
|
||||
|
||||
# Note: The function should raise TextExtractionError, but since it's not imported in the test,
|
||||
# we check for the general Exception pattern
|
||||
assert "Failed to extract text from Excel file" in str(exc_info.value)
|
||||
assert mock_excel_instance.parse.call_count == 2
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_io_bytesio_usage(mock_excel_file):
|
||||
"""Test that BytesIO is properly used with the file content."""
|
||||
import io
|
||||
|
||||
# Mock DataFrame
|
||||
mock_df = Mock()
|
||||
mock_df.dropna = Mock()
|
||||
mock_df.to_markdown.return_value = "| Test | Data |\n|------|------|\n| 1 | A |"
|
||||
# Test bytesio
|
||||
data = {"Test": [1], "Data": ["A"]}
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["TestSheet"]
|
||||
mock_excel_instance.parse.return_value = mock_df
|
||||
mock_excel_instance.parse.return_value = df
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"test_excel_bytes"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
# Verify that ExcelFile was called with a BytesIO object
|
||||
mock_excel_file.assert_called_once()
|
||||
call_args = mock_excel_file.call_args[0][0]
|
||||
assert isinstance(call_args, io.BytesIO)
|
||||
call_arg = mock_excel_file.call_args[0][0]
|
||||
assert isinstance(call_arg, io.BytesIO)
|
||||
|
||||
expected = "| Test | Data |\n|------|------|\n| 1 | A |\n\n"
|
||||
assert result == expected
|
||||
expected_manual = "| Test | Data |\n| ---- | ---- |\n| 1 | A |\n\n"
|
||||
assert expected_manual == result
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
|
||||
"""Test when all sheets fail to parse - should return empty string."""
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
|
||||
|
|
@ -335,29 +326,6 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
|
|||
file_content = b"fake_excel_all_bad_sheets"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
# Should return empty string when all sheets fail
|
||||
assert result == ""
|
||||
|
||||
|
||||
@patch("pandas.ExcelFile")
|
||||
def test_extract_text_from_excel_markdown_formatting(mock_excel_file):
|
||||
"""Test that markdown formatting parameters are correctly applied."""
|
||||
# Mock DataFrame
|
||||
mock_df = Mock()
|
||||
mock_df.dropna = Mock()
|
||||
mock_df.to_markdown.return_value = "| Float | Int |\n|-------|-----|\n| 123456.78 | 42 |"
|
||||
|
||||
# Mock ExcelFile
|
||||
mock_excel_instance = Mock()
|
||||
mock_excel_instance.sheet_names = ["NumberSheet"]
|
||||
mock_excel_instance.parse.return_value = mock_df
|
||||
mock_excel_file.return_value = mock_excel_instance
|
||||
|
||||
file_content = b"fake_excel_numbers"
|
||||
result = _extract_text_from_excel(file_content)
|
||||
|
||||
# Verify to_markdown was called with correct parameters
|
||||
mock_df.to_markdown.assert_called_once_with(index=False, floatfmt="")
|
||||
|
||||
expected = "| Float | Int |\n|-------|-----|\n| 123456.78 | 42 |\n\n"
|
||||
assert result == expected
|
||||
assert mock_excel_instance.parse.call_count == 2
|
||||
|
|
|
|||
|
|
@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600
|
|||
PIP_MIRROR_URL=
|
||||
|
||||
# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example
|
||||
# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss
|
||||
# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos
|
||||
PLUGIN_STORAGE_TYPE=local
|
||||
PLUGIN_STORAGE_LOCAL_ROOT=/app/storage
|
||||
PLUGIN_WORKING_PATH=/app/storage/cwd
|
||||
|
|
@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
|
|||
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
|
||||
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
|
||||
PLUGIN_ALIYUN_OSS_PATH=
|
||||
# Plugin oss volcengine tos
|
||||
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
|
||||
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
|
||||
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
|
||||
PLUGIN_VOLCENGINE_TOS_REGION=
|
||||
|
||||
# ------------------------------
|
||||
# OTLP Collector Configuration
|
||||
|
|
@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
|
|||
|
||||
# Prevent Clickjacking
|
||||
ALLOW_EMBED=false
|
||||
|
||||
# Dataset queue monitor configuration
|
||||
QUEUE_MONITOR_THRESHOLD=200
|
||||
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
|
||||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
|
|
|||
|
|
@ -184,6 +184,10 @@ services:
|
|||
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
|
||||
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
|
||||
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
|
||||
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
|
||||
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
|
||||
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
|
||||
ports:
|
||||
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
|
||||
volumes:
|
||||
|
|
|
|||
|
|
@ -121,6 +121,10 @@ services:
|
|||
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
|
||||
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
|
||||
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
|
||||
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
|
||||
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
|
||||
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
|
||||
ports:
|
||||
- "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
|
||||
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
|
||||
|
|
|
|||
|
|
@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env
|
|||
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
|
||||
PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
|
||||
PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
|
||||
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
|
||||
PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
|
||||
PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
|
||||
ENABLE_OTEL: ${ENABLE_OTEL:-false}
|
||||
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
|
||||
OTLP_API_KEY: ${OTLP_API_KEY:-}
|
||||
|
|
@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env
|
|||
OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000}
|
||||
OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000}
|
||||
ALLOW_EMBED: ${ALLOW_EMBED:-false}
|
||||
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
|
||||
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
|
||||
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
|
||||
|
||||
services:
|
||||
# API service
|
||||
|
|
@ -683,6 +690,10 @@ services:
|
|||
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
|
||||
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
|
||||
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
|
||||
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
|
||||
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
|
||||
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
|
||||
ports:
|
||||
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
|
||||
volumes:
|
||||
|
|
|
|||
|
|
@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
|
|||
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
|
||||
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
|
||||
PLUGIN_ALIYUN_OSS_PATH=
|
||||
# Plugin oss volcengine tos
|
||||
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
|
||||
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
|
||||
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
|
||||
PLUGIN_VOLCENGINE_TOS_REGION=
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class DifyClient:
|
|||
|
||||
def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
||||
data = {"text": text, "user": user, "streaming": streaming}
|
||||
return self._send_request("POST", "/text-to-audio", data=data)
|
||||
return self._send_request("POST", "/text-to-audio", json=data)
|
||||
|
||||
def get_meta(self, user):
|
||||
params = {"user": user}
|
||||
|
|
|
|||
|
|
@ -18,9 +18,10 @@ const queryDateFormat = 'YYYY-MM-DD HH:mm'
|
|||
|
||||
export type IChartViewProps = {
|
||||
appId: string
|
||||
headerRight: React.ReactNode
|
||||
}
|
||||
|
||||
export default function ChartView({ appId }: IChartViewProps) {
|
||||
export default function ChartView({ appId, headerRight }: IChartViewProps) {
|
||||
const { t } = useTranslation()
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
const isChatApp = appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow'
|
||||
|
|
@ -46,19 +47,24 @@ export default function ChartView({ appId }: IChartViewProps) {
|
|||
|
||||
return (
|
||||
<div>
|
||||
<div className='system-xl-semibold mb-4 mt-8 flex flex-row items-center text-text-primary'>
|
||||
<span className='mr-3'>{t('appOverview.analysis.title')}</span>
|
||||
<SimpleSelect
|
||||
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
|
||||
className='mt-0 !w-40'
|
||||
onSelect={(item) => {
|
||||
const id = item.value
|
||||
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
|
||||
const name = item.name || t('appLog.filter.period.allTime')
|
||||
onSelect({ value, name })
|
||||
}}
|
||||
defaultValue={'2'}
|
||||
/>
|
||||
<div className='mb-4'>
|
||||
<div className='system-xl-semibold mb-2 text-text-primary'>{t('common.appMenus.overview')}</div>
|
||||
<div className='flex items-center justify-between'>
|
||||
<div className='flex flex-row items-center'>
|
||||
<SimpleSelect
|
||||
items={Object.entries(TIME_PERIOD_MAPPING).map(([k, v]) => ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))}
|
||||
className='mt-0 !w-40'
|
||||
onSelect={(item) => {
|
||||
const id = item.value
|
||||
const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1'
|
||||
const name = item.name || t('appLog.filter.period.allTime')
|
||||
onSelect({ value, name })
|
||||
}}
|
||||
defaultValue={'2'}
|
||||
/>
|
||||
</div>
|
||||
{headerRight}
|
||||
</div>
|
||||
</div>
|
||||
{!isWorkflow && (
|
||||
<div className='mb-6 grid w-full grid-cols-1 gap-6 xl:grid-cols-2'>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import React from 'react'
|
||||
import ChartView from './chartView'
|
||||
import CardView from './cardView'
|
||||
import TracingPanel from './tracing/panel'
|
||||
import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel'
|
||||
|
||||
|
|
@ -18,9 +17,10 @@ const Overview = async (props: IDevelopProps) => {
|
|||
return (
|
||||
<div className="h-full overflow-scroll bg-chatbot-bg px-4 py-6 sm:px-12">
|
||||
<ApikeyInfoPanel />
|
||||
<TracingPanel />
|
||||
<CardView appId={appId} />
|
||||
<ChartView appId={appId} />
|
||||
<ChartView
|
||||
appId={appId}
|
||||
headerRight={<TracingPanel />}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,19 +23,6 @@ import Divider from '@/app/components/base/divider'
|
|||
|
||||
const I18N_PREFIX = 'app.tracing'
|
||||
|
||||
const Title = ({
|
||||
className,
|
||||
}: {
|
||||
className?: string
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className={cn('system-xl-semibold flex items-center text-text-primary', className)}>
|
||||
{t('common.appMenus.overview')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
const Panel: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const pathname = usePathname()
|
||||
|
|
@ -154,7 +141,6 @@ const Panel: FC = () => {
|
|||
if (!isLoaded) {
|
||||
return (
|
||||
<div className='mb-3 flex items-center justify-between'>
|
||||
<Title className='h-[41px]' />
|
||||
<div className='w-[200px]'>
|
||||
<Loading />
|
||||
</div>
|
||||
|
|
@ -163,8 +149,7 @@ const Panel: FC = () => {
|
|||
}
|
||||
|
||||
return (
|
||||
<div className={cn('mb-3 flex items-center justify-between')}>
|
||||
<Title className='h-[41px]' />
|
||||
<div className={cn('flex items-center justify-between')}>
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center rounded-xl border-l-[0.5px] border-t border-effects-highlight bg-background-default-dodge p-2 shadow-xs hover:border-effects-highlight-lightmode-off hover:bg-background-default-lighter',
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ const weaveConfigTemplate = {
|
|||
entity: '',
|
||||
project: '',
|
||||
endpoint: '',
|
||||
host: '',
|
||||
}
|
||||
|
||||
const ProviderConfigModal: FC<Props> = ({
|
||||
|
|
@ -226,6 +227,13 @@ const ProviderConfigModal: FC<Props> = ({
|
|||
onChange={handleConfigChange('endpoint')}
|
||||
placeholder={'https://trace.wandb.ai/'}
|
||||
/>
|
||||
<Field
|
||||
label='Host'
|
||||
labelClassName='!text-sm'
|
||||
value={(config as WeaveConfig).host}
|
||||
onChange={handleConfigChange('host')}
|
||||
placeholder={'https://api.wandb.ai'}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{type === TracingProvider.langSmith && (
|
||||
|
|
|
|||
|
|
@ -29,4 +29,5 @@ export type WeaveConfig = {
|
|||
entity: string
|
||||
project: string
|
||||
endpoint: string
|
||||
host: string
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import { useContext, useContextSelector } from 'use-context-selector'
|
|||
import { useRouter } from 'next/navigation'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill } from '@remixicon/react'
|
||||
import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { App } from '@/types/app'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
|
|
@ -338,7 +338,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
</div>
|
||||
<div className='flex h-5 w-5 shrink-0 items-center justify-center'>
|
||||
{app.access_mode === AccessMode.PUBLIC && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.anyone')}>
|
||||
<RiGlobalLine className='h-4 w-4 text-text-accent' />
|
||||
<RiGlobalLine className='h-4 w-4 text-text-quaternary' />
|
||||
</Tooltip>}
|
||||
{app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.specific')}>
|
||||
<RiLockLine className='h-4 w-4 text-text-quaternary' />
|
||||
|
|
@ -346,6 +346,9 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
|
|||
{app.access_mode === AccessMode.ORGANIZATION && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.organization')}>
|
||||
<RiBuildingLine className='h-4 w-4 text-text-quaternary' />
|
||||
</Tooltip>}
|
||||
{app.access_mode === AccessMode.EXTERNAL_MEMBERS && <Tooltip asChild={false} popupContent={t('app.accessItemsDescription.external')}>
|
||||
<RiVerifiedBadgeLine className='h-4 w-4 text-text-quaternary' />
|
||||
</Tooltip>}
|
||||
</div>
|
||||
</div>
|
||||
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
|
||||
|
|
|
|||
|
|
@ -88,11 +88,11 @@ const Apps = () => {
|
|||
const anchorRef = useRef<HTMLDivElement>(null)
|
||||
const options = [
|
||||
{ value: 'all', text: t('app.types.all'), icon: <RiApps2Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'workflow', text: t('app.types.workflow'), icon: <RiExchange2Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'advanced-chat', text: t('app.types.advanced'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'chat', text: t('app.types.chatbot'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'agent-chat', text: t('app.types.agent'), icon: <RiRobot3Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'completion', text: t('app.types.completion'), icon: <RiFile4Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'advanced-chat', text: t('app.types.advanced'), icon: <RiMessage3Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
{ value: 'workflow', text: t('app.types.workflow'), icon: <RiExchange2Line className='mr-1 h-[14px] w-[14px]' /> },
|
||||
]
|
||||
|
||||
useEffect(() => {
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ const Container = () => {
|
|||
|
||||
return (
|
||||
<div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'>
|
||||
<div className='sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
|
||||
<div className='sticky top-0 z-10 flex h-[80px] shrink-0 flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
|
||||
<TabSliderNew
|
||||
value={activeTab}
|
||||
onChange={newActiveTab => setActiveTab(newActiveTab)}
|
||||
|
|
|
|||
|
|
@ -192,15 +192,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- original_document_id が渡されない場合、新しい操作が実行され、process_rule が必要です。
|
||||
|
||||
- <code>indexing_technique</code> インデックスモード
|
||||
- <code>high_quality</code> 高品質: 埋め込みモデルを使用してベクトルデータベースインデックスを構築
|
||||
- <code>economy</code> 経済: キーワードテーブルインデックスの反転インデックスを構築
|
||||
- <code>high_quality</code> 高品質:埋め込みモデルを使用してベクトルデータベースインデックスを構築
|
||||
- <code>economy</code> 経済:キーワードテーブルインデックスの反転インデックスを構築
|
||||
|
||||
- <code>doc_form</code> インデックス化された内容の形式
|
||||
- <code>text_model</code> テキストドキュメントは直接埋め込まれます; `economy` モードではこの形式がデフォルト
|
||||
- <code>hierarchical_model</code> 親子モード
|
||||
- <code>qa_model</code> Q&A モード: 分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
|
||||
- <code>qa_model</code> Q&A モード:分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
|
||||
|
||||
- <code>doc_language</code> Q&A モードでは、ドキュメントの言語を指定します。例: <code>English</code>, <code>Chinese</code>
|
||||
- <code>doc_language</code> Q&A モードでは、ドキュメントの言語を指定します。例:<code>English</code>, <code>Chinese</code>
|
||||
|
||||
- <code>process_rule</code> 処理ルール
|
||||
- <code>mode</code> (string) クリーニング、セグメンテーションモード、自動 / カスタム
|
||||
|
|
@ -214,7 +214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>segmentation</code> (object) セグメンテーションルール
|
||||
- <code>separator</code> カスタムセグメント識別子。現在は 1 つの区切り文字のみ設定可能。デフォルトは \n
|
||||
- <code>max_tokens</code> 最大長 (トークン) デフォルトは 1000
|
||||
- <code>parent_mode</code> 親チャンクの検索モード: <code>full-doc</code> 全文検索 / <code>paragraph</code> 段落検索
|
||||
- <code>parent_mode</code> 親チャンクの検索モード:<code>full-doc</code> 全文検索 / <code>paragraph</code> 段落検索
|
||||
- <code>subchunk_segmentation</code> (object) 子チャンクルール
|
||||
- <code>separator</code> セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは <code>***</code>
|
||||
- <code>max_tokens</code> 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
|
||||
|
|
@ -324,7 +324,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>partial_members</code> 一部のメンバー
|
||||
</Property>
|
||||
<Property name='provider' type='string' key='provider'>
|
||||
プロバイダー (オプション、デフォルト: vendor)
|
||||
プロバイダー (オプション、デフォルト:vendor)
|
||||
- <code>vendor</code> ベンダー
|
||||
- <code>external</code> 外部ナレッジ
|
||||
</Property>
|
||||
|
|
@ -415,16 +415,16 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
検索キーワード、オプション
|
||||
</Property>
|
||||
<Property name='tag_ids' type='array[string]' key='tag_ids'>
|
||||
タグIDリスト、オプション
|
||||
タグ ID リスト、オプション
|
||||
</Property>
|
||||
<Property name='page' type='string' key='page'>
|
||||
ページ番号、オプション、デフォルト1
|
||||
ページ番号、オプション、デフォルト 1
|
||||
</Property>
|
||||
<Property name='limit' type='string' key='limit'>
|
||||
返されるアイテム数、オプション、デフォルト20、範囲1-100
|
||||
返されるアイテム数、オプション、デフォルト 20、範囲 1-100
|
||||
</Property>
|
||||
<Property name='include_all' type='boolean' key='include_all'>
|
||||
すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトはfalse
|
||||
すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトは false
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2013,7 +2013,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string'>
|
||||
(text) 新しいタグ名、必須、最大長50文字
|
||||
(text) 新しいタグ名、必須、最大長 50 文字
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2099,10 +2099,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string'>
|
||||
(text) 変更後のタグ名、必須、最大長50文字
|
||||
(text) 変更後のタグ名、必須、最大長 50 文字
|
||||
</Property>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) タグID、必須
|
||||
(text) タグ ID、必須
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2147,7 +2147,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) タグID、必須
|
||||
(text) タグ ID、必須
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2188,10 +2188,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_ids' type='list'>
|
||||
(list) タグIDリスト、必須
|
||||
(list) タグ ID リスト、必須
|
||||
</Property>
|
||||
<Property name='target_id' type='string'>
|
||||
(text) ナレッジベースID、必須
|
||||
(text) ナレッジベース ID、必須
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2230,10 +2230,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) タグID、必須
|
||||
(text) タグ ID、必須
|
||||
</Property>
|
||||
<Property name='target_id' type='string'>
|
||||
(text) ナレッジベースID、必須
|
||||
(text) ナレッジベース ID、必須
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2273,7 +2273,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Path
|
||||
<Properties>
|
||||
<Property name='dataset_id' type='string'>
|
||||
(text) ナレッジベースID
|
||||
(text) ナレッジベース ID
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
|
||||
|
||||
- <code>process_rule</code> 处理规则
|
||||
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 / hierarchical 父子
|
||||
- <code>mode</code> (string) 清洗、分段模式,automatic 自动 / custom 自定义 / hierarchical 父子
|
||||
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
|
||||
- <code>pre_processing_rules</code> (array[object]) 预处理规则
|
||||
- <code>id</code> (string) 预处理规则的唯一标识符
|
||||
|
|
@ -234,12 +234,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>hybrid_search</code> 混合检索
|
||||
- <code>semantic_search</code> 语义检索
|
||||
- <code>full_text_search</code> 全文检索
|
||||
- <code>reranking_enable</code> (bool) 是否开启rerank
|
||||
- <code>reranking_enable</code> (bool) 是否开启 rerank
|
||||
- <code>reranking_model</code> (object) Rerank 模型配置
|
||||
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
|
||||
- <code>reranking_model_name</code> (string) Rerank 模型的名称
|
||||
- <code>top_k</code> (int) 召回条数
|
||||
- <code>score_threshold_enabled</code> (bool)是否开启召回分数限制
|
||||
- <code>score_threshold_enabled</code> (bool) 是否开启召回分数限制
|
||||
- <code>score_threshold</code> (float) 召回分数限制
|
||||
</Property>
|
||||
<Property name='embedding_model' type='string' key='embedding_model'>
|
||||
|
|
@ -350,12 +350,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>hybrid_search</code> 混合检索
|
||||
- <code>semantic_search</code> 语义检索
|
||||
- <code>full_text_search</code> 全文检索
|
||||
- <code>reranking_enable</code> (bool) 是否开启rerank
|
||||
- <code>reranking_enable</code> (bool) 是否开启 rerank
|
||||
- <code>reranking_model</code> (object) Rerank 模型配置
|
||||
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
|
||||
- <code>reranking_model_name</code> (string) Rerank 模型的名称
|
||||
- <code>top_k</code> (int) 召回条数
|
||||
- <code>score_threshold_enabled</code> (bool)是否开启召回分数限制
|
||||
- <code>score_threshold_enabled</code> (bool) 是否开启召回分数限制
|
||||
- <code>score_threshold</code> (float) 召回分数限制
|
||||
</Property>
|
||||
</Properties>
|
||||
|
|
@ -1322,7 +1322,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
文档 ID
|
||||
</Property>
|
||||
<Property name='segment_id' type='string' key='segment_id'>
|
||||
文档分段ID
|
||||
文档分段 ID
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -1435,7 +1435,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
文档 ID
|
||||
</Property>
|
||||
<Property name='segment_id' type='string' key='segment_id'>
|
||||
文档分段ID
|
||||
文档分段 ID
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
|
|
@ -2223,7 +2223,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
- <code>document_id</code> (string) 文档 ID
|
||||
- <code>metadata_list</code> (list) 元数据列表
|
||||
- <code>id</code> (string) 元数据 ID
|
||||
- <code>type</code> (string) 元数据类型
|
||||
- <code>value</code> (string) 元数据值
|
||||
- <code>name</code> (string) 元数据名称
|
||||
</Property>
|
||||
</Properties>
|
||||
|
|
@ -2404,7 +2404,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string'>
|
||||
(text) 新标签名称,必填,最大长度为50
|
||||
(text) 新标签名称,必填,最大长度为 50
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2490,10 +2490,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='name' type='string'>
|
||||
(text) 修改后的标签名称,必填,最大长度为50
|
||||
(text) 修改后的标签名称,必填,最大长度为 50
|
||||
</Property>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) 标签ID,必填
|
||||
(text) 标签 ID,必填
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2538,7 +2538,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) 标签ID,必填
|
||||
(text) 标签 ID,必填
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2579,10 +2579,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_ids' type='list'>
|
||||
(list) 标签ID列表,必填
|
||||
(list) 标签 ID 列表,必填
|
||||
</Property>
|
||||
<Property name='target_id' type='string'>
|
||||
(text) 知识库ID,必填
|
||||
(text) 知识库 ID,必填
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2621,10 +2621,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Request Body
|
||||
<Properties>
|
||||
<Property name='tag_id' type='string'>
|
||||
(text) 标签ID,必填
|
||||
(text) 标签 ID,必填
|
||||
</Property>
|
||||
<Property name='target_id' type='string'>
|
||||
(text) 知识库ID,必填
|
||||
(text) 知识库 ID,必填
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
@ -2664,7 +2664,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
|
|||
### Path
|
||||
<Properties>
|
||||
<Property name='dataset_id' type='string'>
|
||||
(text) 知识库ID
|
||||
(text) 知识库 ID
|
||||
</Property>
|
||||
</Properties>
|
||||
</Col>
|
||||
|
|
|
|||
|
|
@ -1,14 +1,42 @@
|
|||
import React from 'react'
|
||||
'use client'
|
||||
import React, { useEffect, useState } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import type { Metadata } from 'next'
|
||||
|
||||
export const metadata: Metadata = {
|
||||
icons: 'data:,', // prevent browser from using default favicon
|
||||
}
|
||||
import { usePathname, useSearchParams } from 'next/navigation'
|
||||
import Loading from '../components/base/loading'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { getAppAccessModeByAppCode } from '@/service/share'
|
||||
|
||||
const Layout: FC<{
|
||||
children: React.ReactNode
|
||||
}> = ({ children }) => {
|
||||
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
|
||||
const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode)
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
let appCode: string | null = null
|
||||
if (redirectUrl)
|
||||
appCode = redirectUrl?.split('/').pop() || null
|
||||
else
|
||||
appCode = pathname.split('/').pop() || null
|
||||
|
||||
if (!appCode)
|
||||
return
|
||||
setIsLoading(true)
|
||||
const ret = await getAppAccessModeByAppCode(appCode)
|
||||
setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC)
|
||||
setIsLoading(false)
|
||||
})()
|
||||
}, [pathname, redirectUrl, setWebAppAccessMode])
|
||||
if (isLoading || isGlobalPending) {
|
||||
return <div className='flex h-full w-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
return (
|
||||
<div className="h-full min-w-[300px] pb-[env(safe-area-inset-bottom)]">
|
||||
{children}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
'use client'
|
||||
import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useState } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Countdown from '@/app/components/signin/countdown'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common'
|
||||
import I18NContext from '@/context/i18n'
|
||||
|
||||
export default function CheckCode() {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const email = decodeURIComponent(searchParams.get('email') as string)
|
||||
const token = decodeURIComponent(searchParams.get('token') as string)
|
||||
const [code, setVerifyCode] = useState('')
|
||||
const [loading, setIsLoading] = useState(false)
|
||||
const { locale } = useContext(I18NContext)
|
||||
|
||||
const verify = async () => {
|
||||
try {
|
||||
if (!code.trim()) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.checkCode.emptyCode'),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!/\d{6}/.test(code)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.checkCode.invalidCode'),
|
||||
})
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
const ret = await verifyWebAppResetPasswordCode({ email, code, token })
|
||||
if (ret.is_valid) {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.set('token', encodeURIComponent(ret.token))
|
||||
router.push(`/webapp-reset-password/set-password?${params.toString()}`)
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const resendCode = async () => {
|
||||
try {
|
||||
const res = await sendWebAppResetPasswordCode(email, locale)
|
||||
if (res.result === 'success') {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.set('token', encodeURIComponent(res.data))
|
||||
router.replace(`/webapp-reset-password/check-code?${params.toString()}`)
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
}
|
||||
|
||||
return <div className='flex flex-col gap-3'>
|
||||
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge text-text-accent-light-mode-only shadow-lg'>
|
||||
<RiMailSendFill className='h-6 w-6 text-2xl' />
|
||||
</div>
|
||||
<div className='pb-4 pt-2'>
|
||||
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
|
||||
<br />
|
||||
{t('login.checkCode.validTime')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form action="">
|
||||
<input type='text' className='hidden' />
|
||||
<label htmlFor="code" className='system-md-semibold mb-1 text-text-secondary'>{t('login.checkCode.verificationCode')}</label>
|
||||
<Input value={code} onChange={e => setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} />
|
||||
<Button loading={loading} disabled={loading} className='my-3 w-full' variant='primary' onClick={verify}>{t('login.checkCode.verify')}</Button>
|
||||
<Countdown onResend={resendCode} />
|
||||
</form>
|
||||
<div className='py-2'>
|
||||
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
<div onClick={() => router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
|
||||
<div className='bg-background-default-dimm inline-block rounded-full p-1'>
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className='system-xs-regular ml-2'>{t('login.back')}</span>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
'use client'
|
||||
import Header from '@/app/signin/_header'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
export default function SignInLayout({ children }: any) {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
return <>
|
||||
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
|
||||
<div className={cn('flex w-full shrink-0 flex-col rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
|
||||
<Header />
|
||||
<div className={
|
||||
cn(
|
||||
'flex w-full grow flex-col items-center justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
}>
|
||||
<div className='flex w-[400px] flex-col'>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
{!systemFeatures.branding.enabled && <div className='system-xs-regular px-8 py-6 text-text-tertiary'>
|
||||
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
|
||||
</div>}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
'use client'
|
||||
import Link from 'next/link'
|
||||
import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useState } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
|
||||
import { emailRegex } from '@/config'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { sendResetPasswordCode } from '@/service/common'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { noop } from 'lodash-es'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
|
||||
export default function CheckCode() {
|
||||
const { t } = useTranslation()
|
||||
useDocumentTitle('')
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const [email, setEmail] = useState('')
|
||||
const [loading, setIsLoading] = useState(false)
|
||||
const { locale } = useContext(I18NContext)
|
||||
|
||||
const handleGetEMailVerificationCode = async () => {
|
||||
try {
|
||||
if (!email) {
|
||||
Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
|
||||
return
|
||||
}
|
||||
|
||||
if (!emailRegex.test(email)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.emailInValid'),
|
||||
})
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
const res = await sendResetPasswordCode(email, locale)
|
||||
if (res.result === 'success') {
|
||||
localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`)
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.set('token', encodeURIComponent(res.data))
|
||||
params.set('email', encodeURIComponent(email))
|
||||
router.push(`/webapp-reset-password/check-code?${params.toString()}`)
|
||||
}
|
||||
else if (res.code === 'account_not_found') {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.registrationNotAllowed'),
|
||||
})
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: res.data,
|
||||
})
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
}
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
return <div className='flex flex-col gap-3'>
|
||||
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge shadow-lg'>
|
||||
<RiLockPasswordLine className='h-6 w-6 text-2xl text-text-accent-light-mode-only' />
|
||||
</div>
|
||||
<div className='pb-4 pt-2'>
|
||||
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.resetPassword')}</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
{t('login.resetPasswordDesc')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={noop}>
|
||||
<input type='text' className='hidden' />
|
||||
<div className='mb-2'>
|
||||
<label htmlFor="email" className='system-md-semibold my-2 text-text-secondary'>{t('login.email')}</label>
|
||||
<div className='mt-1'>
|
||||
<Input id='email' type="email" disabled={loading} value={email} placeholder={t('login.emailPlaceholder') as string} onChange={e => setEmail(e.target.value)} />
|
||||
</div>
|
||||
<div className='mt-3'>
|
||||
<Button loading={loading} disabled={loading} variant='primary' className='w-full' onClick={handleGetEMailVerificationCode}>{t('login.sendVerificationCode')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
<div className='py-2'>
|
||||
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
<Link href={`/webapp-signin?${searchParams.toString()}`} className='flex h-9 items-center justify-center text-text-tertiary hover:text-text-primary'>
|
||||
<div className='inline-block rounded-full bg-background-default-dimmed p-1'>
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className='system-xs-regular ml-2'>{t('login.backToLogin')}</span>
|
||||
</Link>
|
||||
</div>
|
||||
}
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
'use client'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import cn from 'classnames'
|
||||
import { RiCheckboxCircleFill } from '@remixicon/react'
|
||||
import { useCountDown } from 'ahooks'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { changeWebAppPasswordWithToken } from '@/service/common'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Input from '@/app/components/base/input'
|
||||
|
||||
const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
|
||||
|
||||
const ChangePasswordForm = () => {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const token = decodeURIComponent(searchParams.get('token') || '')
|
||||
|
||||
const [password, setPassword] = useState('')
|
||||
const [confirmPassword, setConfirmPassword] = useState('')
|
||||
const [showSuccess, setShowSuccess] = useState(false)
|
||||
const [showPassword, setShowPassword] = useState(false)
|
||||
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
|
||||
|
||||
const showErrorMessage = useCallback((message: string) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message,
|
||||
})
|
||||
}, [])
|
||||
|
||||
const getSignInUrl = () => {
|
||||
return `/webapp-signin?redirect_url=${searchParams.get('redirect_url') || ''}`
|
||||
}
|
||||
|
||||
const AUTO_REDIRECT_TIME = 5000
|
||||
const [leftTime, setLeftTime] = useState<number | undefined>(undefined)
|
||||
const [countdown] = useCountDown({
|
||||
leftTime,
|
||||
onEnd: () => {
|
||||
router.replace(getSignInUrl())
|
||||
},
|
||||
})
|
||||
|
||||
const valid = useCallback(() => {
|
||||
if (!password.trim()) {
|
||||
showErrorMessage(t('login.error.passwordEmpty'))
|
||||
return false
|
||||
}
|
||||
if (!validPassword.test(password)) {
|
||||
showErrorMessage(t('login.error.passwordInvalid'))
|
||||
return false
|
||||
}
|
||||
if (password !== confirmPassword) {
|
||||
showErrorMessage(t('common.account.notEqual'))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}, [password, confirmPassword, showErrorMessage, t])
|
||||
|
||||
const handleChangePassword = useCallback(async () => {
|
||||
if (!valid())
|
||||
return
|
||||
try {
|
||||
await changeWebAppPasswordWithToken({
|
||||
url: '/forgot-password/resets',
|
||||
body: {
|
||||
token,
|
||||
new_password: password,
|
||||
password_confirm: confirmPassword,
|
||||
},
|
||||
})
|
||||
setShowSuccess(true)
|
||||
setLeftTime(AUTO_REDIRECT_TIME)
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
}
|
||||
}, [password, token, valid, confirmPassword])
|
||||
|
||||
return (
|
||||
<div className={
|
||||
cn(
|
||||
'flex w-full grow flex-col items-center justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
}>
|
||||
{!showSuccess && (
|
||||
<div className='flex flex-col md:w-[400px]'>
|
||||
<div className="mx-auto w-full">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">
|
||||
{t('login.changePassword')}
|
||||
</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
{t('login.changePasswordTip')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="mx-auto mt-6 w-full">
|
||||
<div className="bg-white">
|
||||
{/* Password */}
|
||||
<div className='mb-5'>
|
||||
<label htmlFor="password" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('common.account.newPassword')}
|
||||
</label>
|
||||
<div className='relative mt-1'>
|
||||
<Input
|
||||
id="password" type={showPassword ? 'text' : 'password'}
|
||||
value={password}
|
||||
onChange={e => setPassword(e.target.value)}
|
||||
placeholder={t('login.passwordPlaceholder') || ''}
|
||||
/>
|
||||
|
||||
<div className="absolute inset-y-0 right-0 flex items-center">
|
||||
<Button
|
||||
type="button"
|
||||
variant='ghost'
|
||||
onClick={() => setShowPassword(!showPassword)}
|
||||
>
|
||||
{showPassword ? '👀' : '😝'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className='body-xs-regular mt-1 text-text-secondary'>{t('login.error.passwordInvalid')}</div>
|
||||
</div>
|
||||
{/* Confirm Password */}
|
||||
<div className='mb-5'>
|
||||
<label htmlFor="confirmPassword" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('common.account.confirmPassword')}
|
||||
</label>
|
||||
<div className='relative mt-1'>
|
||||
<Input
|
||||
id="confirmPassword"
|
||||
type={showConfirmPassword ? 'text' : 'password'}
|
||||
value={confirmPassword}
|
||||
onChange={e => setConfirmPassword(e.target.value)}
|
||||
placeholder={t('login.confirmPasswordPlaceholder') || ''}
|
||||
/>
|
||||
<div className="absolute inset-y-0 right-0 flex items-center">
|
||||
<Button
|
||||
type="button"
|
||||
variant='ghost'
|
||||
onClick={() => setShowConfirmPassword(!showConfirmPassword)}
|
||||
>
|
||||
{showConfirmPassword ? '👀' : '😝'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<Button
|
||||
variant='primary'
|
||||
className='w-full'
|
||||
onClick={handleChangePassword}
|
||||
>
|
||||
{t('login.changePasswordBtn')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{showSuccess && (
|
||||
<div className="flex flex-col md:w-[400px]">
|
||||
<div className="mx-auto w-full">
|
||||
<div className="mb-3 flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle font-bold shadow-lg">
|
||||
<RiCheckboxCircleFill className='h-6 w-6 text-text-success' />
|
||||
</div>
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">
|
||||
{t('login.passwordChangedTip')}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="mx-auto mt-6 w-full">
|
||||
<Button variant='primary' className='w-full' onClick={() => {
|
||||
setLeftTime(undefined)
|
||||
router.replace(getSignInUrl())
|
||||
}}>{t('login.passwordChanged')} ({Math.round(countdown / 1000)}) </Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ChangePasswordForm
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
'use client'
|
||||
import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Countdown from '@/app/components/signin/countdown'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { setAccessToken } from '@/app/components/share/utils'
|
||||
import { fetchAccessToken } from '@/service/share'
|
||||
|
||||
export default function CheckCode() {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const email = decodeURIComponent(searchParams.get('email') as string)
|
||||
const token = decodeURIComponent(searchParams.get('token') as string)
|
||||
const [code, setVerifyCode] = useState('')
|
||||
const [loading, setIsLoading] = useState(false)
|
||||
const { locale } = useContext(I18NContext)
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
return appCode
|
||||
}, [redirectUrl])
|
||||
|
||||
const verify = async () => {
|
||||
try {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!code.trim()) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.checkCode.emptyCode'),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!/\d{6}/.test(code)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.checkCode.invalidCode'),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!redirectUrl || !appCode) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.redirectUrlMissing'),
|
||||
})
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
const ret = await webAppEmailLoginWithCode({ email, code, token })
|
||||
if (ret.result === 'success') {
|
||||
localStorage.setItem('webapp_access_token', ret.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
const resendCode = async () => {
|
||||
try {
|
||||
const ret = await sendWebAppEMailLoginCode(email, locale)
|
||||
if (ret.result === 'success') {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.set('token', encodeURIComponent(ret.data))
|
||||
router.replace(`/webapp-signin/check-code?${params.toString()}`)
|
||||
}
|
||||
}
|
||||
catch (error) { console.error(error) }
|
||||
}
|
||||
|
||||
return <div className='flex w-[400px] flex-col gap-3'>
|
||||
<div className='inline-flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge shadow-lg'>
|
||||
<RiMailSendFill className='h-6 w-6 text-2xl text-text-accent-light-mode-only' />
|
||||
</div>
|
||||
<div className='pb-4 pt-2'>
|
||||
<h2 className='title-4xl-semi-bold text-text-primary'>{t('login.checkCode.checkYourEmail')}</h2>
|
||||
<p className='body-md-regular mt-2 text-text-secondary'>
|
||||
<span dangerouslySetInnerHTML={{ __html: t('login.checkCode.tips', { email }) as string }}></span>
|
||||
<br />
|
||||
{t('login.checkCode.validTime')}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<form action="">
|
||||
<label htmlFor="code" className='system-md-semibold mb-1 text-text-secondary'>{t('login.checkCode.verificationCode')}</label>
|
||||
<Input value={code} onChange={e => setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} />
|
||||
<Button loading={loading} disabled={loading} className='my-3 w-full' variant='primary' onClick={verify}>{t('login.checkCode.verify')}</Button>
|
||||
<Countdown onResend={resendCode} />
|
||||
</form>
|
||||
<div className='py-2'>
|
||||
<div className='h-px bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
<div onClick={() => router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
|
||||
<div className='bg-background-default-dimm inline-block rounded-full p-1'>
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className='system-xs-regular ml-2'>{t('login.back')}</span>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
'use client'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import React, { useCallback, useEffect } from 'react'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { SSOProtocol } from '@/types/feature'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
|
||||
const ExternalMemberSSOAuth = () => {
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const showErrorToast = (message: string) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
return appCode
|
||||
}, [redirectUrl])
|
||||
|
||||
const handleSSOLogin = useCallback(async () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!appCode || !redirectUrl) {
|
||||
showErrorToast('redirect url or app code is invalid.')
|
||||
return
|
||||
}
|
||||
|
||||
switch (systemFeatures.webapp_auth.sso_config.protocol) {
|
||||
case SSOProtocol.SAML: {
|
||||
const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl)
|
||||
router.push(samlRes.url)
|
||||
break
|
||||
}
|
||||
case SSOProtocol.OIDC: {
|
||||
const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl)
|
||||
router.push(oidcRes.url)
|
||||
break
|
||||
}
|
||||
case SSOProtocol.OAuth2: {
|
||||
const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl)
|
||||
router.push(oauth2Res.url)
|
||||
break
|
||||
}
|
||||
case '':
|
||||
break
|
||||
default:
|
||||
showErrorToast('SSO protocol is not supported.')
|
||||
}
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol])
|
||||
|
||||
useEffect(() => {
|
||||
handleSSOLogin()
|
||||
}, [handleSSOLogin])
|
||||
|
||||
if (!systemFeatures.webapp_auth.sso_config.protocol) {
|
||||
return <div className="flex h-full items-center justify-center">
|
||||
<AppUnavailable code={403} unknownReason='sso protocol is invalid.' />
|
||||
</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(ExternalMemberSSOAuth)
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { emailRegex } from '@/config'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { sendWebAppEMailLoginCode } from '@/service/common'
|
||||
import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
export default function MailAndCodeAuth() {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const emailFromLink = decodeURIComponent(searchParams.get('email') || '')
|
||||
const [email, setEmail] = useState(emailFromLink)
|
||||
const [loading, setIsLoading] = useState(false)
|
||||
const { locale } = useContext(I18NContext)
|
||||
|
||||
const handleGetEMailVerificationCode = async () => {
|
||||
try {
|
||||
if (!email) {
|
||||
Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
|
||||
return
|
||||
}
|
||||
|
||||
if (!emailRegex.test(email)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.emailInValid'),
|
||||
})
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
const ret = await sendWebAppEMailLoginCode(email, locale)
|
||||
if (ret.result === 'success') {
|
||||
localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`)
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.set('email', encodeURIComponent(email))
|
||||
params.set('token', encodeURIComponent(ret.data))
|
||||
router.push(`/webapp-signin/check-code?${params.toString()}`)
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
}
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (<form onSubmit={noop}>
|
||||
<input type='text' className='hidden' />
|
||||
<div className='mb-2'>
|
||||
<label htmlFor="email" className='system-md-semibold my-2 text-text-secondary'>{t('login.email')}</label>
|
||||
<div className='mt-1'>
|
||||
<Input id='email' type="email" value={email} placeholder={t('login.emailPlaceholder') as string} onChange={e => setEmail(e.target.value)} />
|
||||
</div>
|
||||
<div className='mt-3'>
|
||||
<Button loading={loading} disabled={loading || !email} variant='primary' className='w-full' onClick={handleGetEMailVerificationCode}>{t('login.continueWithCode')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
import Link from 'next/link'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { emailRegex } from '@/config'
|
||||
import { webAppLogin } from '@/service/common'
|
||||
import Input from '@/app/components/base/input'
|
||||
import I18NContext from '@/context/i18n'
|
||||
import { noop } from 'lodash-es'
|
||||
import { setAccessToken } from '@/app/components/share/utils'
|
||||
import { fetchAccessToken } from '@/service/share'
|
||||
|
||||
type MailAndPasswordAuthProps = {
|
||||
isEmailSetup: boolean
|
||||
}
|
||||
|
||||
const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
|
||||
|
||||
export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) {
|
||||
const { t } = useTranslation()
|
||||
const { locale } = useContext(I18NContext)
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const [showPassword, setShowPassword] = useState(false)
|
||||
const emailFromLink = decodeURIComponent(searchParams.get('email') || '')
|
||||
const [email, setEmail] = useState(emailFromLink)
|
||||
const [password, setPassword] = useState('')
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
return appCode
|
||||
}, [redirectUrl])
|
||||
const handleEmailPasswordLogin = async () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!email) {
|
||||
Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
|
||||
return
|
||||
}
|
||||
if (!emailRegex.test(email)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.emailInValid'),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!password?.trim()) {
|
||||
Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') })
|
||||
return
|
||||
}
|
||||
if (!passwordRegex.test(password)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.passwordInvalid'),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!redirectUrl || !appCode) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('login.error.redirectUrlMissing'),
|
||||
})
|
||||
return
|
||||
}
|
||||
try {
|
||||
setIsLoading(true)
|
||||
const loginData: Record<string, any> = {
|
||||
email,
|
||||
password,
|
||||
language: locale,
|
||||
remember_me: true,
|
||||
}
|
||||
|
||||
const res = await webAppLogin({
|
||||
url: '/login',
|
||||
body: loginData,
|
||||
})
|
||||
if (res.result === 'success') {
|
||||
localStorage.setItem('webapp_access_token', res.data.access_token)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: res.data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
return <form onSubmit={noop}>
|
||||
<div className='mb-3'>
|
||||
<label htmlFor="email" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('login.email')}
|
||||
</label>
|
||||
<div className="mt-1">
|
||||
<Input
|
||||
value={email}
|
||||
onChange={e => setEmail(e.target.value)}
|
||||
id="email"
|
||||
type="email"
|
||||
autoComplete="email"
|
||||
placeholder={t('login.emailPlaceholder') || ''}
|
||||
tabIndex={1}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='mb-3'>
|
||||
<label htmlFor="password" className="my-2 flex items-center justify-between">
|
||||
<span className='system-md-semibold text-text-secondary'>{t('login.password')}</span>
|
||||
<Link
|
||||
href={`/webapp-reset-password?${searchParams.toString()}`}
|
||||
className={`system-xs-regular ${isEmailSetup ? 'text-components-button-secondary-accent-text' : 'pointer-events-none text-components-button-secondary-accent-text-disabled'}`}
|
||||
tabIndex={isEmailSetup ? 0 : -1}
|
||||
aria-disabled={!isEmailSetup}
|
||||
>
|
||||
{t('login.forget')}
|
||||
</Link>
|
||||
</label>
|
||||
<div className="relative mt-1">
|
||||
<Input
|
||||
id="password"
|
||||
value={password}
|
||||
onChange={e => setPassword(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter')
|
||||
handleEmailPasswordLogin()
|
||||
}}
|
||||
type={showPassword ? 'text' : 'password'}
|
||||
autoComplete="current-password"
|
||||
placeholder={t('login.passwordPlaceholder') || ''}
|
||||
tabIndex={2}
|
||||
/>
|
||||
<div className="absolute inset-y-0 right-0 flex items-center">
|
||||
<Button
|
||||
type="button"
|
||||
variant='ghost'
|
||||
onClick={() => setShowPassword(!showPassword)}
|
||||
>
|
||||
{showPassword ? '👀' : '😝'}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='mb-2'>
|
||||
<Button
|
||||
tabIndex={2}
|
||||
variant='primary'
|
||||
onClick={handleEmailPasswordLogin}
|
||||
disabled={isLoading || !email || !password}
|
||||
className="w-full"
|
||||
>{t('login.signBtn')}</Button>
|
||||
</div>
|
||||
</form>
|
||||
}
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
'use client'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import type { FC } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { SSOProtocol } from '@/types/feature'
|
||||
import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share'
|
||||
|
||||
type SSOAuthProps = {
|
||||
protocol: SSOProtocol | ''
|
||||
}
|
||||
|
||||
const SSOAuth: FC<SSOAuthProps> = ({
|
||||
protocol,
|
||||
}) => {
|
||||
const router = useRouter()
|
||||
const { t } = useTranslation()
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const getAppCodeFromRedirectUrl = useCallback(() => {
|
||||
const appCode = redirectUrl?.split('/').pop()
|
||||
if (!appCode)
|
||||
return null
|
||||
|
||||
return appCode
|
||||
}, [redirectUrl])
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
|
||||
const handleSSOLogin = () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!redirectUrl || !appCode) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'invalid redirect URL or app code',
|
||||
})
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
if (protocol === SSOProtocol.SAML) {
|
||||
fetchMembersSAMLSSOUrl(appCode, redirectUrl).then((res) => {
|
||||
router.push(res.url)
|
||||
}).finally(() => {
|
||||
setIsLoading(false)
|
||||
})
|
||||
}
|
||||
else if (protocol === SSOProtocol.OIDC) {
|
||||
fetchMembersOIDCSSOUrl(appCode, redirectUrl).then((res) => {
|
||||
router.push(res.url)
|
||||
}).finally(() => {
|
||||
setIsLoading(false)
|
||||
})
|
||||
}
|
||||
else if (protocol === SSOProtocol.OAuth2) {
|
||||
fetchMembersOAuth2SSOUrl(appCode, redirectUrl).then((res) => {
|
||||
router.push(res.url)
|
||||
}).finally(() => {
|
||||
setIsLoading(false)
|
||||
})
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: 'invalid SSO protocol',
|
||||
})
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
tabIndex={0}
|
||||
onClick={() => { handleSSOLogin() }}
|
||||
disabled={isLoading}
|
||||
className="w-full"
|
||||
>
|
||||
<Lock01 className='mr-2 h-5 w-5 text-text-accent-light-mode-only' />
|
||||
<span className="truncate">{t('login.withSSO')}</span>
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
|
||||
export default SSOAuth
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
'use client'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
|
||||
export default function SignInLayout({ children }: any) {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
useDocumentTitle('')
|
||||
return <>
|
||||
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>
|
||||
<div className={cn('flex w-full shrink-0 flex-col rounded-2xl border border-effects-highlight bg-background-default-subtle')}>
|
||||
{/* <Header /> */}
|
||||
<div className={cn('flex w-full grow flex-col items-center justify-center px-6 md:px-[108px]')}>
|
||||
<div className='flex justify-center md:w-[440px] lg:w-[600px]'>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
{systemFeatures.branding.enabled === false && <div className='system-xs-regular px-8 py-6 text-text-tertiary'>
|
||||
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
|
||||
</div>}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
import React, { useCallback, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Link from 'next/link'
|
||||
import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import MailAndCodeAuth from './components/mail-and-code-auth'
|
||||
import MailAndPasswordAuth from './components/mail-and-password-auth'
|
||||
import SSOAuth from './components/sso-auth'
|
||||
import cn from '@/utils/classnames'
|
||||
import { LicenseStatus } from '@/types/feature'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
const NormalForm = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
const [authType, updateAuthType] = useState<'code' | 'password'>('password')
|
||||
const [showORLine, setShowORLine] = useState(false)
|
||||
const [allMethodsAreDisabled, setAllMethodsAreDisabled] = useState(false)
|
||||
|
||||
const init = useCallback(async () => {
|
||||
try {
|
||||
setAllMethodsAreDisabled(!systemFeatures.enable_social_oauth_login && !systemFeatures.enable_email_code_login && !systemFeatures.enable_email_password_login && !systemFeatures.sso_enforced_for_signin)
|
||||
setShowORLine((systemFeatures.enable_social_oauth_login || systemFeatures.sso_enforced_for_signin) && (systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login))
|
||||
updateAuthType(systemFeatures.enable_email_password_login ? 'password' : 'code')
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
setAllMethodsAreDisabled(true)
|
||||
}
|
||||
finally { setIsLoading(false) }
|
||||
}, [systemFeatures])
|
||||
useEffect(() => {
|
||||
init()
|
||||
}, [init])
|
||||
if (isLoading) {
|
||||
return <div className={
|
||||
cn(
|
||||
'flex w-full grow flex-col items-center justify-center',
|
||||
'px-6',
|
||||
'md:px-[108px]',
|
||||
)
|
||||
}>
|
||||
<Loading type='area' />
|
||||
</div>
|
||||
}
|
||||
if (systemFeatures.license?.status === LicenseStatus.LOST) {
|
||||
return <div className='mx-auto mt-8 w-full'>
|
||||
<div className='relative'>
|
||||
<div className="rounded-lg bg-gradient-to-r from-workflow-workflow-progress-bg-1 to-workflow-workflow-progress-bg-2 p-4">
|
||||
<div className='shadows-shadow-lg relative mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow'>
|
||||
<RiContractLine className='h-5 w-5' />
|
||||
<RiErrorWarningFill className='absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary' />
|
||||
</div>
|
||||
<p className='system-sm-medium text-text-primary'>{t('login.licenseLost')}</p>
|
||||
<p className='system-xs-regular mt-1 text-text-tertiary'>{t('login.licenseLostTip')}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
if (systemFeatures.license?.status === LicenseStatus.EXPIRED) {
|
||||
return <div className='mx-auto mt-8 w-full'>
|
||||
<div className='relative'>
|
||||
<div className="rounded-lg bg-gradient-to-r from-workflow-workflow-progress-bg-1 to-workflow-workflow-progress-bg-2 p-4">
|
||||
<div className='shadows-shadow-lg relative mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow'>
|
||||
<RiContractLine className='h-5 w-5' />
|
||||
<RiErrorWarningFill className='absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary' />
|
||||
</div>
|
||||
<p className='system-sm-medium text-text-primary'>{t('login.licenseExpired')}</p>
|
||||
<p className='system-xs-regular mt-1 text-text-tertiary'>{t('login.licenseExpiredTip')}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
if (systemFeatures.license?.status === LicenseStatus.INACTIVE) {
|
||||
return <div className='mx-auto mt-8 w-full'>
|
||||
<div className='relative'>
|
||||
<div className="rounded-lg bg-gradient-to-r from-workflow-workflow-progress-bg-1 to-workflow-workflow-progress-bg-2 p-4">
|
||||
<div className='shadows-shadow-lg relative mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow'>
|
||||
<RiContractLine className='h-5 w-5' />
|
||||
<RiErrorWarningFill className='absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary' />
|
||||
</div>
|
||||
<p className='system-sm-medium text-text-primary'>{t('login.licenseInactive')}</p>
|
||||
<p className='system-xs-regular mt-1 text-text-tertiary'>{t('login.licenseInactiveTip')}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="mx-auto mt-8 w-full">
|
||||
<div className="mx-auto w-full">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">{t('login.pageTitle')}</h2>
|
||||
{!systemFeatures.branding.enabled && <p className='body-md-regular mt-2 text-text-tertiary'>{t('login.welcome')}</p>}
|
||||
</div>
|
||||
<div className="relative">
|
||||
<div className="mt-6 flex flex-col gap-3">
|
||||
{systemFeatures.sso_enforced_for_signin && <div className='w-full'>
|
||||
<SSOAuth protocol={systemFeatures.sso_enforced_for_signin_protocol} />
|
||||
</div>}
|
||||
</div>
|
||||
|
||||
{showORLine && <div className="relative mt-6">
|
||||
<div className="absolute inset-0 flex items-center" aria-hidden="true">
|
||||
<div className='h-px w-full bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
<div className="relative flex justify-center">
|
||||
<span className="system-xs-medium-uppercase px-2 text-text-tertiary">{t('login.or')}</span>
|
||||
</div>
|
||||
</div>}
|
||||
{
|
||||
(systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login) && <>
|
||||
{systemFeatures.enable_email_code_login && authType === 'code' && <>
|
||||
<MailAndCodeAuth />
|
||||
{systemFeatures.enable_email_password_login && <div className='cursor-pointer py-1 text-center' onClick={() => { updateAuthType('password') }}>
|
||||
<span className='system-xs-medium text-components-button-secondary-accent-text'>{t('login.usePassword')}</span>
|
||||
</div>}
|
||||
</>}
|
||||
{systemFeatures.enable_email_password_login && authType === 'password' && <>
|
||||
<MailAndPasswordAuth isEmailSetup={systemFeatures.is_email_setup} />
|
||||
{systemFeatures.enable_email_code_login && <div className='cursor-pointer py-1 text-center' onClick={() => { updateAuthType('code') }}>
|
||||
<span className='system-xs-medium text-components-button-secondary-accent-text'>{t('login.useVerificationCode')}</span>
|
||||
</div>}
|
||||
</>}
|
||||
</>
|
||||
}
|
||||
{allMethodsAreDisabled && <>
|
||||
<div className="rounded-lg bg-gradient-to-r from-workflow-workflow-progress-bg-1 to-workflow-workflow-progress-bg-2 p-4">
|
||||
<div className='shadows-shadow-lg mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow'>
|
||||
<RiDoorLockLine className='h-5 w-5' />
|
||||
</div>
|
||||
<p className='system-sm-medium text-text-primary'>{t('login.noLoginMethod')}</p>
|
||||
<p className='system-xs-regular mt-1 text-text-tertiary'>{t('login.noLoginMethodTip')}</p>
|
||||
</div>
|
||||
<div className="relative my-2 py-2">
|
||||
<div className="absolute inset-0 flex items-center" aria-hidden="true">
|
||||
<div className='h-px w-full bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
</div>
|
||||
</>}
|
||||
{!systemFeatures.branding.enabled && <>
|
||||
<div className="system-xs-regular mt-2 block w-full text-text-tertiary">
|
||||
{t('login.tosDesc')}
|
||||
|
||||
<Link
|
||||
className='system-xs-medium text-text-secondary hover:underline'
|
||||
target='_blank' rel='noopener noreferrer'
|
||||
href='https://dify.ai/terms'
|
||||
>{t('login.tos')}</Link>
|
||||
&
|
||||
<Link
|
||||
className='system-xs-medium text-text-secondary hover:underline'
|
||||
target='_blank' rel='noopener noreferrer'
|
||||
href='https://dify.ai/privacy'
|
||||
>{t('login.pp')}</Link>
|
||||
</div>
|
||||
{IS_CE_EDITION && <div className="w-hull system-xs-regular mt-2 block text-text-tertiary">
|
||||
{t('login.goToInit')}
|
||||
|
||||
<Link
|
||||
className='system-xs-medium text-text-secondary hover:underline'
|
||||
href='/install'
|
||||
>{t('login.setAdminAccount')}</Link>
|
||||
</div>}
|
||||
</>}
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default NormalForm
|
||||
|
|
@ -3,30 +3,45 @@ import { useRouter, useSearchParams } from 'next/navigation'
|
|||
import type { FC } from 'react'
|
||||
import React, { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiDoorLockLine } from '@remixicon/react'
|
||||
import cn from '@/utils/classnames'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
|
||||
import { setAccessToken } from '@/app/components/share/utils'
|
||||
import { removeAccessToken, setAccessToken } from '@/app/components/share/utils'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { SSOProtocol } from '@/types/feature'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import NormalForm from './normalForm'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import ExternalMemberSsoAuth from './components/external-member-sso-auth'
|
||||
import { fetchAccessToken } from '@/service/share'
|
||||
|
||||
const WebSSOForm: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const webAppAccessMode = useGlobalPublicStore(s => s.webAppAccessMode)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const tokenFromUrl = searchParams.get('web_sso_token')
|
||||
const message = searchParams.get('message')
|
||||
const code = searchParams.get('code')
|
||||
|
||||
const showErrorToast = (message: string) => {
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
params.delete('code')
|
||||
return `/webapp-signin?${params.toString()}`
|
||||
}, [searchParams])
|
||||
|
||||
const backToHome = useCallback(() => {
|
||||
removeAccessToken()
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router])
|
||||
|
||||
const showErrorToast = (msg: string) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message,
|
||||
message: msg,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -38,102 +53,73 @@ const WebSSOForm: FC = () => {
|
|||
return appCode
|
||||
}, [redirectUrl])
|
||||
|
||||
const processTokenAndRedirect = useCallback(async () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!appCode || !tokenFromUrl || !redirectUrl) {
|
||||
showErrorToast('redirect url or app code or token is invalid.')
|
||||
return
|
||||
}
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
if (message)
|
||||
return
|
||||
|
||||
await setAccessToken(appCode, tokenFromUrl)
|
||||
router.push(redirectUrl)
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl])
|
||||
|
||||
const handleSSOLogin = useCallback(async () => {
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (!appCode || !redirectUrl) {
|
||||
showErrorToast('redirect url or app code is invalid.')
|
||||
return
|
||||
}
|
||||
|
||||
switch (systemFeatures.webapp_auth.sso_config.protocol) {
|
||||
case SSOProtocol.SAML: {
|
||||
const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl)
|
||||
router.push(samlRes.url)
|
||||
break
|
||||
const appCode = getAppCodeFromRedirectUrl()
|
||||
if (appCode && tokenFromUrl && redirectUrl) {
|
||||
localStorage.setItem('webapp_access_token', tokenFromUrl)
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
return
|
||||
}
|
||||
case SSOProtocol.OIDC: {
|
||||
const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl)
|
||||
router.push(oidcRes.url)
|
||||
break
|
||||
if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
|
||||
const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
|
||||
await setAccessToken(appCode, tokenResp.access_token)
|
||||
router.replace(redirectUrl)
|
||||
}
|
||||
case SSOProtocol.OAuth2: {
|
||||
const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl)
|
||||
router.push(oauth2Res.url)
|
||||
break
|
||||
}
|
||||
case '':
|
||||
break
|
||||
default:
|
||||
showErrorToast('SSO protocol is not supported.')
|
||||
}
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol])
|
||||
})()
|
||||
}, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
|
||||
|
||||
useEffect(() => {
|
||||
const init = async () => {
|
||||
if (message) {
|
||||
showErrorToast(message)
|
||||
return
|
||||
}
|
||||
if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
|
||||
router.replace(redirectUrl)
|
||||
}, [webAppAccessMode, router, redirectUrl])
|
||||
|
||||
if (!tokenFromUrl) {
|
||||
await handleSSOLogin()
|
||||
return
|
||||
}
|
||||
|
||||
await processTokenAndRedirect()
|
||||
}
|
||||
|
||||
init()
|
||||
}, [message, processTokenAndRedirect, tokenFromUrl, handleSSOLogin])
|
||||
if (tokenFromUrl)
|
||||
return <div className='flex h-full items-center justify-center'><Loading /></div>
|
||||
if (message) {
|
||||
if (tokenFromUrl) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable code={'App Unavailable'} unknownReason={message} />
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
|
||||
if (systemFeatures.webapp_auth.enabled) {
|
||||
if (systemFeatures.webapp_auth.allow_sso) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<div className={cn('flex w-full grow flex-col items-center justify-center', 'px-6', 'md:px-[108px]')}>
|
||||
<Loading />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return <div className="flex h-full items-center justify-center">
|
||||
<div className="rounded-lg bg-gradient-to-r from-workflow-workflow-progress-bg-1 to-workflow-workflow-progress-bg-2 p-4">
|
||||
<div className='shadows-shadow-lg mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow'>
|
||||
<RiDoorLockLine className='h-5 w-5' />
|
||||
</div>
|
||||
<p className='system-sm-medium text-text-primary'>{t('login.webapp.noLoginMethod')}</p>
|
||||
<p className='system-xs-regular mt-1 text-text-tertiary'>{t('login.webapp.noLoginMethodTip')}</p>
|
||||
</div>
|
||||
<div className="relative my-2 py-2">
|
||||
<div className="absolute inset-0 flex items-center" aria-hidden="true">
|
||||
<div className='h-px w-full bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent'></div>
|
||||
</div>
|
||||
</div>
|
||||
if (message) {
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-4'>
|
||||
<AppUnavailable className='h-auto w-auto' code={code || t('share.common.appUnavailable')} unknownReason={message} />
|
||||
<span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}</span>
|
||||
</div>
|
||||
}
|
||||
else {
|
||||
if (!redirectUrl) {
|
||||
showErrorToast('redirect url is invalid.')
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable code={t('share.common.appUnavailable')} unknownReason='redirect url is invalid.' />
|
||||
</div>
|
||||
}
|
||||
if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
if (!systemFeatures.webapp_auth.enabled) {
|
||||
return <div className="flex h-full items-center justify-center">
|
||||
<p className='system-xs-regular text-text-tertiary'>{t('login.webapp.disabled')}</p>
|
||||
</div>
|
||||
}
|
||||
if (webAppAccessMode && (webAppAccessMode === AccessMode.ORGANIZATION || webAppAccessMode === AccessMode.SPECIFIC_GROUPS_MEMBERS)) {
|
||||
return <div className='w-full max-w-[400px]'>
|
||||
<NormalForm />
|
||||
</div>
|
||||
}
|
||||
|
||||
if (webAppAccessMode && webAppAccessMode === AccessMode.EXTERNAL_MEMBERS)
|
||||
return <ExternalMemberSsoAuth />
|
||||
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-4'>
|
||||
<AppUnavailable className='h-auto w-auto' isUnknownReason={true} />
|
||||
<span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{t('share.login.backToHome')}</span>
|
||||
</div>
|
||||
}
|
||||
|
||||
export default React.memo(WebSSOForm)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
'use client'
|
||||
import { Dialog } from '@headlessui/react'
|
||||
import { RiBuildingLine, RiGlobalLine } from '@remixicon/react'
|
||||
import { Description as DialogDescription, DialogTitle } from '@headlessui/react'
|
||||
import { RiBuildingLine, RiGlobalLine, RiVerifiedBadgeLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import Button from '../../base/button'
|
||||
|
|
@ -67,8 +67,8 @@ export default function AccessControl(props: AccessControlProps) {
|
|||
return <AccessControlDialog show onClose={onClose}>
|
||||
<div className='flex flex-col gap-y-3'>
|
||||
<div className='pb-3 pl-6 pr-14 pt-6'>
|
||||
<Dialog.Title className='title-2xl-semi-bold text-text-primary'>{t('app.accessControlDialog.title')}</Dialog.Title>
|
||||
<Dialog.Description className='system-xs-regular mt-1 text-text-tertiary'>{t('app.accessControlDialog.description')}</Dialog.Description>
|
||||
<DialogTitle className='title-2xl-semi-bold text-text-primary'>{t('app.accessControlDialog.title')}</DialogTitle>
|
||||
<DialogDescription className='system-xs-regular mt-1 text-text-tertiary'>{t('app.accessControlDialog.description')}</DialogDescription>
|
||||
</div>
|
||||
<div className='flex flex-col gap-y-1 px-6 pb-3'>
|
||||
<div className='leading-6'>
|
||||
|
|
@ -80,12 +80,20 @@ export default function AccessControl(props: AccessControlProps) {
|
|||
<RiBuildingLine className='h-4 w-4 text-text-primary' />
|
||||
<p className='system-sm-medium text-text-primary'>{t('app.accessControlDialog.accessItems.organization')}</p>
|
||||
</div>
|
||||
{!hideTip && <WebAppSSONotEnabledTip />}
|
||||
</div>
|
||||
</AccessControlItem>
|
||||
<AccessControlItem type={AccessMode.SPECIFIC_GROUPS_MEMBERS}>
|
||||
<SpecificGroupsOrMembers />
|
||||
</AccessControlItem>
|
||||
<AccessControlItem type={AccessMode.EXTERNAL_MEMBERS}>
|
||||
<div className='flex items-center p-3'>
|
||||
<div className='flex grow items-center gap-x-2'>
|
||||
<RiVerifiedBadgeLine className='h-4 w-4 text-text-primary' />
|
||||
<p className='system-sm-medium text-text-primary'>{t('app.accessControlDialog.accessItems.external')}</p>
|
||||
</div>
|
||||
{!hideTip && <WebAppSSONotEnabledTip />}
|
||||
</div>
|
||||
</AccessControlItem>
|
||||
<AccessControlItem type={AccessMode.PUBLIC}>
|
||||
<div className='flex items-center gap-x-2 p-3'>
|
||||
<RiGlobalLine className='h-4 w-4 text-text-primary' />
|
||||
|
|
|
|||
|
|
@ -3,12 +3,10 @@ import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from
|
|||
import { useTranslation } from 'react-i18next'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import Avatar from '../../base/avatar'
|
||||
import Divider from '../../base/divider'
|
||||
import Tooltip from '../../base/tooltip'
|
||||
import Loading from '../../base/loading'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import type { AccessControlAccount, AccessControlGroup } from '@/models/access-control'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||
|
|
@ -19,11 +17,6 @@ export default function SpecificGroupsOrMembers() {
|
|||
const setSpecificGroups = useAccessControlStore(s => s.setSpecificGroups)
|
||||
const setSpecificMembers = useAccessControlStore(s => s.setSpecificMembers)
|
||||
const { t } = useTranslation()
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const hideTip = systemFeatures.webapp_auth.enabled
|
||||
&& (systemFeatures.webapp_auth.allow_sso
|
||||
|| systemFeatures.webapp_auth.allow_email_password_login
|
||||
|| systemFeatures.webapp_auth.allow_email_code_login)
|
||||
|
||||
const { isPending, data } = useAppWhiteListSubjects(appId, Boolean(appId) && currentMenu === AccessMode.SPECIFIC_GROUPS_MEMBERS)
|
||||
useEffect(() => {
|
||||
|
|
@ -37,7 +30,6 @@ export default function SpecificGroupsOrMembers() {
|
|||
<RiLockLine className='h-4 w-4 text-text-primary' />
|
||||
<p className='system-sm-medium text-text-primary'>{t('app.accessControlDialog.accessItems.specific')}</p>
|
||||
</div>
|
||||
{!hideTip && <WebAppSSONotEnabledTip />}
|
||||
</div>
|
||||
}
|
||||
|
||||
|
|
@ -48,10 +40,6 @@ export default function SpecificGroupsOrMembers() {
|
|||
<p className='system-sm-medium text-text-primary'>{t('app.accessControlDialog.accessItems.specific')}</p>
|
||||
</div>
|
||||
<div className='flex items-center gap-x-1'>
|
||||
{!hideTip && <>
|
||||
<WebAppSSONotEnabledTip />
|
||||
<Divider className='ml-2 mr-0 h-[14px]' type="vertical" />
|
||||
</>}
|
||||
<AddMemberOrGroupDialog />
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -9,11 +9,14 @@ import dayjs from 'dayjs'
|
|||
import {
|
||||
RiArrowDownSLine,
|
||||
RiArrowRightSLine,
|
||||
RiBuildingLine,
|
||||
RiGlobalLine,
|
||||
RiLockLine,
|
||||
RiPlanetLine,
|
||||
RiPlayCircleLine,
|
||||
RiPlayList2Line,
|
||||
RiTerminalBoxLine,
|
||||
RiVerifiedBadgeLine,
|
||||
} from '@remixicon/react'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import { getKeyboardKeyCodeBySystem } from '../../workflow/utils'
|
||||
|
|
@ -275,11 +278,33 @@ const AppPublisher = ({
|
|||
onClick={() => {
|
||||
setShowAppAccessControl(true)
|
||||
}}>
|
||||
<div className='flex grow items-center gap-x-1.5 pr-1'>
|
||||
<RiLockLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
{appDetail?.access_mode === AccessMode.ORGANIZATION && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.organization')}</p>}
|
||||
{appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.specific')}</p>}
|
||||
{appDetail?.access_mode === AccessMode.PUBLIC && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.anyone')}</p>}
|
||||
<div className='flex grow items-center gap-x-1.5 overflow-hidden pr-1'>
|
||||
{appDetail?.access_mode === AccessMode.ORGANIZATION
|
||||
&& <>
|
||||
<RiBuildingLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.organization')}</p>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS
|
||||
&& <>
|
||||
<RiLockLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<div className='grow truncate'>
|
||||
<span className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.specific')}</span>
|
||||
</div>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.PUBLIC
|
||||
&& <>
|
||||
<RiGlobalLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.anyone')}</p>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS
|
||||
&& <>
|
||||
<RiVerifiedBadgeLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.external')}</p>
|
||||
</>
|
||||
}
|
||||
</div>
|
||||
{!isAppAccessSet && <p className='system-xs-regular shrink-0 text-text-tertiary'>{t('app.publishApp.notSet')}</p>}
|
||||
<div className='flex h-4 w-4 shrink-0 items-center justify-center'>
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useContext, useContextSelector } from 'use-context-selector'
|
||||
import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react'
|
||||
import Link from 'next/link'
|
||||
|
|
@ -19,7 +19,6 @@ import AppsContext, { useAppContext } from '@/context/app-context'
|
|||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import type { AppMode } from '@/types/app'
|
||||
import { AppModes } from '@/types/app'
|
||||
import { createApp } from '@/service/apps'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
|
|
@ -56,14 +55,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
|
|||
|
||||
const isCreatingRef = useRef(false)
|
||||
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
useEffect(() => {
|
||||
const category = searchParams.get('category')
|
||||
if (category && AppModes.includes(category as AppMode))
|
||||
setAppMode(category as AppMode)
|
||||
}, [searchParams])
|
||||
|
||||
const onCreate = useCallback(async () => {
|
||||
if (!appMode) {
|
||||
notify({ type: 'error', message: t('app.newApp.appTypeRequired') })
|
||||
|
|
@ -128,7 +119,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps)
|
|||
onClick={() => {
|
||||
setAppMode('workflow')
|
||||
}} />
|
||||
<AppTypeCard
|
||||
<AppTypeCard
|
||||
active={appMode === 'advanced-chat'}
|
||||
title={t('app.types.advanced')}
|
||||
description={t('app.newApp.advancedShortDescription')}
|
||||
|
|
|
|||
|
|
@ -5,10 +5,13 @@ import { useTranslation } from 'react-i18next'
|
|||
import {
|
||||
RiArrowRightSLine,
|
||||
RiBookOpenLine,
|
||||
RiBuildingLine,
|
||||
RiEqualizer2Line,
|
||||
RiExternalLinkLine,
|
||||
RiGlobalLine,
|
||||
RiLockLine,
|
||||
RiPaintBrushLine,
|
||||
RiVerifiedBadgeLine,
|
||||
RiWindowLine,
|
||||
} from '@remixicon/react'
|
||||
import SettingsModal from './settings'
|
||||
|
|
@ -248,11 +251,30 @@ function AppCard({
|
|||
<div className='flex h-9 w-full cursor-pointer items-center gap-x-0.5 rounded-lg bg-components-input-bg-normal py-1 pl-2.5 pr-2'
|
||||
onClick={handleClickAccessControl}>
|
||||
<div className='flex grow items-center gap-x-1.5 pr-1'>
|
||||
<RiLockLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
{appDetail?.access_mode === AccessMode.ORGANIZATION && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.organization')}</p>}
|
||||
{appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.specific')}</p>}
|
||||
{appDetail?.access_mode === AccessMode.PUBLIC && <p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.anyone')}</p>}
|
||||
</div>
|
||||
{appDetail?.access_mode === AccessMode.ORGANIZATION
|
||||
&& <>
|
||||
<RiBuildingLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.organization')}</p>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS
|
||||
&& <>
|
||||
<RiLockLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.specific')}</p>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.PUBLIC
|
||||
&& <>
|
||||
<RiGlobalLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.anyone')}</p>
|
||||
</>
|
||||
}
|
||||
{appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS
|
||||
&& <>
|
||||
<RiVerifiedBadgeLine className='h-4 w-4 shrink-0 text-text-secondary' />
|
||||
<p className='system-sm-medium text-text-secondary'>{t('app.accessControlDialog.accessItems.external')}</p>
|
||||
</>
|
||||
}</div>
|
||||
{!isAppAccessSet && <p className='system-xs-regular shrink-0 text-text-tertiary'>{t('app.publishApp.notSet')}</p>}
|
||||
<div className='flex h-4 w-4 shrink-0 items-center justify-center'>
|
||||
<RiArrowRightSLine className='h-4 w-4 text-text-quaternary' />
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ export type AppSelectorProps = {
|
|||
onChange: (value: AppSelectorProps['value']) => void
|
||||
}
|
||||
|
||||
const allTypes: AppMode[] = ['chat', 'agent-chat', 'completion', 'advanced-chat', 'workflow']
|
||||
const allTypes: AppMode[] = ['workflow', 'advanced-chat', 'chat', 'agent-chat', 'completion']
|
||||
|
||||
const AppTypeSelector = ({ value, onChange }: AppSelectorProps) => {
|
||||
const [open, setOpen] = useState(false)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
'use client'
|
||||
import classNames from '@/utils/classnames'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
|
@ -7,18 +8,20 @@ type IAppUnavailableProps = {
|
|||
code?: number | string
|
||||
isUnknownReason?: boolean
|
||||
unknownReason?: string
|
||||
className?: string
|
||||
}
|
||||
|
||||
const AppUnavailable: FC<IAppUnavailableProps> = ({
|
||||
code = 404,
|
||||
isUnknownReason,
|
||||
unknownReason,
|
||||
className,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className='flex h-screen w-screen items-center justify-center'>
|
||||
<h1 className='mr-5 h-[50px] pr-5 text-[24px] font-medium leading-[50px]'
|
||||
<div className={classNames('flex h-screen w-screen items-center justify-center', className)}>
|
||||
<h1 className='mr-5 h-[50px] shrink-0 pr-5 text-[24px] font-medium leading-[50px]'
|
||||
style={{
|
||||
borderRight: '1px solid rgba(0,0,0,.3)',
|
||||
}}>{code}</h1>
|
||||
|
|
|
|||
|
|
@ -16,14 +16,12 @@ import type {
|
|||
ConversationItem,
|
||||
} from '@/models/share'
|
||||
import { noop } from 'lodash-es'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
export type ChatWithHistoryContextValue = {
|
||||
appInfoError?: any
|
||||
appInfoLoading?: boolean
|
||||
appMeta?: AppMeta
|
||||
appData?: AppData
|
||||
accessMode?: AccessMode
|
||||
userCanAccess?: boolean
|
||||
appParams?: ChatConfig
|
||||
appChatListDataLoading?: boolean
|
||||
|
|
@ -64,7 +62,6 @@ export type ChatWithHistoryContextValue = {
|
|||
}
|
||||
|
||||
export const ChatWithHistoryContext = createContext<ChatWithHistoryContextValue>({
|
||||
accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
|
||||
userCanAccess: false,
|
||||
currentConversationId: '',
|
||||
appPrevChatTree: [],
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import type {
|
|||
Feedback,
|
||||
} from '../types'
|
||||
import { CONVERSATION_ID_INFO } from '../constants'
|
||||
import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils'
|
||||
import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams, getRawInputsFromUrlParams } from '../utils'
|
||||
import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
|
||||
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
|
||||
import {
|
||||
|
|
@ -43,9 +43,8 @@ import { useAppFavicon } from '@/hooks/use-app-favicon'
|
|||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetAppAccessMode, useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
function getFormattedChatList(messages: any[]) {
|
||||
const newChatList: ChatItem[] = []
|
||||
|
|
@ -77,11 +76,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo])
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR(installedAppInfo ? null : 'appInfo', fetchAppInfo)
|
||||
const { isPending: isGettingAccessMode, data: appAccessMode } = useGetAppAccessMode({
|
||||
appId: installedAppInfo?.app.id || appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
enabled: systemFeatures.webapp_auth.enabled,
|
||||
})
|
||||
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
|
||||
appId: installedAppInfo?.app.id || appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
|
|
@ -195,6 +189,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
const { t } = useTranslation()
|
||||
const newConversationInputsRef = useRef<Record<string, any>>({})
|
||||
const [newConversationInputs, setNewConversationInputs] = useState<Record<string, any>>({})
|
||||
const [initInputs, setInitInputs] = useState<Record<string, any>>({})
|
||||
const handleNewConversationInputsChange = useCallback((newInputs: Record<string, any>) => {
|
||||
newConversationInputsRef.current = newInputs
|
||||
setNewConversationInputs(newInputs)
|
||||
|
|
@ -202,20 +197,29 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
const inputsForms = useMemo(() => {
|
||||
return (appParams?.user_input_form || []).filter((item: any) => !item.external_data_tool).map((item: any) => {
|
||||
if (item.paragraph) {
|
||||
let value = initInputs[item.paragraph.variable]
|
||||
if (value && item.paragraph.max_length && value.length > item.paragraph.max_length)
|
||||
value = value.slice(0, item.paragraph.max_length)
|
||||
|
||||
return {
|
||||
...item.paragraph,
|
||||
default: value || item.default,
|
||||
type: 'paragraph',
|
||||
}
|
||||
}
|
||||
if (item.number) {
|
||||
const convertedNumber = Number(initInputs[item.number.variable]) ?? undefined
|
||||
return {
|
||||
...item.number,
|
||||
default: convertedNumber || item.default,
|
||||
type: 'number',
|
||||
}
|
||||
}
|
||||
if (item.select) {
|
||||
const isInputInOptions = item.select.options.includes(initInputs[item.select.variable])
|
||||
return {
|
||||
...item.select,
|
||||
default: (isInputInOptions ? initInputs[item.select.variable] : undefined) || item.default,
|
||||
type: 'select',
|
||||
}
|
||||
}
|
||||
|
|
@ -234,17 +238,30 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
}
|
||||
}
|
||||
|
||||
let value = initInputs[item['text-input'].variable]
|
||||
if (value && item['text-input'].max_length && value.length > item['text-input'].max_length)
|
||||
value = value.slice(0, item['text-input'].max_length)
|
||||
|
||||
return {
|
||||
...item['text-input'],
|
||||
default: value || item.default,
|
||||
type: 'text-input',
|
||||
}
|
||||
})
|
||||
}, [appParams])
|
||||
}, [initInputs, appParams])
|
||||
|
||||
const allInputsHidden = useMemo(() => {
|
||||
return inputsForms.length > 0 && inputsForms.every(item => item.hide === true)
|
||||
}, [inputsForms])
|
||||
|
||||
useEffect(() => {
|
||||
// init inputs from url params
|
||||
(async () => {
|
||||
const inputs = await getRawInputsFromUrlParams()
|
||||
setInitInputs(inputs)
|
||||
})()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const conversationInputs: Record<string, any> = {}
|
||||
|
||||
|
|
@ -362,11 +379,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
if (conversationId)
|
||||
setClearChatList(false)
|
||||
}, [handleConversationIdInfoChange, setClearChatList])
|
||||
const handleNewConversation = useCallback(() => {
|
||||
const handleNewConversation = useCallback(async () => {
|
||||
currentChatInstanceRef.current.handleStop()
|
||||
setShowNewConversationItemInList(true)
|
||||
handleChangeConversation('')
|
||||
handleNewConversationInputsChange({})
|
||||
handleNewConversationInputsChange(await getRawInputsFromUrlParams())
|
||||
setClearChatList(true)
|
||||
}, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList])
|
||||
const handleUpdateConversationList = useCallback(() => {
|
||||
|
|
@ -469,8 +486,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
|
|||
|
||||
return {
|
||||
appInfoError,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && (isGettingAccessMode || isCheckingPermission)),
|
||||
accessMode: systemFeatures.webapp_auth.enabled ? appAccessMode?.accessMode : AccessMode.PUBLIC,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
|
||||
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
|
||||
isInstalledApp,
|
||||
appId,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
|
@ -17,10 +19,12 @@ import ChatWrapper from './chat-wrapper'
|
|||
import type { InstalledApp } from '@/models/explore'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { checkOrSetAccessToken } from '@/app/components/share/utils'
|
||||
import { checkOrSetAccessToken, removeAccessToken } from '@/app/components/share/utils'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import cn from '@/utils/classnames'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
|
||||
type ChatWithHistoryProps = {
|
||||
className?: string
|
||||
|
|
@ -38,6 +42,7 @@ const ChatWithHistory: FC<ChatWithHistoryProps> = ({
|
|||
isMobile,
|
||||
themeBuilder,
|
||||
sidebarCollapseState,
|
||||
isInstalledApp,
|
||||
} = useChatWithHistoryContext()
|
||||
const isSidebarCollapsed = sidebarCollapseState
|
||||
const customConfig = appData?.custom_config
|
||||
|
|
@ -51,13 +56,34 @@ const ChatWithHistory: FC<ChatWithHistoryProps> = ({
|
|||
|
||||
useDocumentTitle(site?.title || 'Chat')
|
||||
|
||||
const { t } = useTranslation()
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
params.set('redirect_url', pathname)
|
||||
return `/webapp-signin?${params.toString()}`
|
||||
}, [searchParams, pathname])
|
||||
|
||||
const backToHome = useCallback(() => {
|
||||
removeAccessToken()
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router])
|
||||
|
||||
if (appInfoLoading) {
|
||||
return (
|
||||
<Loading type='app' />
|
||||
)
|
||||
}
|
||||
if (!userCanAccess)
|
||||
return <AppUnavailable code={403} unknownReason='no permission.' />
|
||||
if (!userCanAccess) {
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-2'>
|
||||
<AppUnavailable className='h-auto w-auto' code={403} unknownReason='no permission.' />
|
||||
{!isInstalledApp && <span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{t('common.userProfile.logout')}</span>}
|
||||
</div>
|
||||
}
|
||||
|
||||
if (appInfoError) {
|
||||
return (
|
||||
|
|
@ -124,7 +150,6 @@ const ChatWithHistoryWrap: FC<ChatWithHistoryWrapProps> = ({
|
|||
const {
|
||||
appInfoError,
|
||||
appInfoLoading,
|
||||
accessMode,
|
||||
userCanAccess,
|
||||
appData,
|
||||
appParams,
|
||||
|
|
@ -169,7 +194,6 @@ const ChatWithHistoryWrap: FC<ChatWithHistoryWrapProps> = ({
|
|||
appInfoError,
|
||||
appInfoLoading,
|
||||
appData,
|
||||
accessMode,
|
||||
userCanAccess,
|
||||
appParams,
|
||||
appMeta,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ import RenameModal from '@/app/components/base/chat/chat-with-history/sidebar/re
|
|||
import DifyLogo from '@/app/components/base/logo/dify-logo'
|
||||
import type { ConversationItem } from '@/models/share'
|
||||
import cn from '@/utils/classnames'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
|
||||
type Props = {
|
||||
|
|
@ -30,7 +29,6 @@ const Sidebar = ({ isPanel }: Props) => {
|
|||
const { t } = useTranslation()
|
||||
const {
|
||||
isInstalledApp,
|
||||
accessMode,
|
||||
appData,
|
||||
handleNewConversation,
|
||||
pinnedConversationList,
|
||||
|
|
@ -140,7 +138,7 @@ const Sidebar = ({ isPanel }: Props) => {
|
|||
)}
|
||||
</div>
|
||||
<div className='flex shrink-0 items-center justify-between p-3'>
|
||||
<MenuDropdown hideLogout={isInstalledApp || accessMode === AccessMode.PUBLIC} placement='top-start' data={appData?.site} />
|
||||
<MenuDropdown hideLogout={isInstalledApp} placement='top-start' data={appData?.site} />
|
||||
{/* powered by */}
|
||||
<div className='shrink-0'>
|
||||
{!appData?.custom_config?.remove_webapp_brand && (
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ export const markdownContentSVG = `
|
|||
<svg width="400" height="600" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="100%" height="100%" fill="#F0F8FF"/>
|
||||
|
||||
<text x="50%" y="60" font-family="楷体" font-size="32" fill="#4682B4" text-anchor="middle">创意Logo设计</text>
|
||||
<text x="50%" y="60" font-family="楷体" font-size="32" fill="#4682B4" text-anchor="middle">创意 Logo 设计</text>
|
||||
|
||||
<line x1="50" y1="80" x2="350" y2="80" stroke="#B0C4DE" stroke-width="2"/>
|
||||
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ export const useChat = (
|
|||
if (!newResponseItem)
|
||||
return
|
||||
|
||||
const isUseAgentThought = newResponseItem.agent_thoughts?.length > 0
|
||||
const isUseAgentThought = newResponseItem.agent_thoughts?.length > 0 && newResponseItem.agent_thoughts[newResponseItem.agent_thoughts?.length - 1].thought === newResponseItem.answer
|
||||
updateChatTreeNode(responseItem.id, {
|
||||
content: isUseAgentThought ? '' : newResponseItem.answer,
|
||||
log: [
|
||||
|
|
|
|||
|
|
@ -303,7 +303,7 @@ const Chat: FC<ChatProps> = ({
|
|||
{
|
||||
!noChatInput && (
|
||||
<ChatInputArea
|
||||
botName={appData?.site.title || ''}
|
||||
botName={appData?.site.title || 'Bot'}
|
||||
disabled={inputDisabled}
|
||||
showFeatureBar={showFeatureBar}
|
||||
showFileUpload={showFileUpload}
|
||||
|
|
|
|||
|
|
@ -15,10 +15,8 @@ import type {
|
|||
ConversationItem,
|
||||
} from '@/models/share'
|
||||
import { noop } from 'lodash-es'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
export type EmbeddedChatbotContextValue = {
|
||||
accessMode?: AccessMode
|
||||
userCanAccess?: boolean
|
||||
appInfoError?: any
|
||||
appInfoLoading?: boolean
|
||||
|
|
@ -58,7 +56,6 @@ export type EmbeddedChatbotContextValue = {
|
|||
|
||||
export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({
|
||||
userCanAccess: false,
|
||||
accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
|
||||
currentConversationId: '',
|
||||
appPrevChatList: [],
|
||||
pinnedConversationList: [],
|
||||
|
|
|
|||
|
|
@ -36,9 +36,8 @@ import { InputVarType } from '@/app/components/workflow/types'
|
|||
import { TransferMethod } from '@/types/app'
|
||||
import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useGetAppAccessMode, useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
function getFormattedChatList(messages: any[]) {
|
||||
const newChatList: ChatItem[] = []
|
||||
|
|
@ -70,11 +69,6 @@ export const useEmbeddedChatbot = () => {
|
|||
const isInstalledApp = false
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo)
|
||||
const { isPending: isGettingAccessMode, data: appAccessMode } = useGetAppAccessMode({
|
||||
appId: appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
enabled: systemFeatures.webapp_auth.enabled,
|
||||
})
|
||||
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
|
||||
appId: appInfo?.app_id,
|
||||
isInstalledApp,
|
||||
|
|
@ -385,8 +379,7 @@ export const useEmbeddedChatbot = () => {
|
|||
|
||||
return {
|
||||
appInfoError,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && (isGettingAccessMode || isCheckingPermission)),
|
||||
accessMode: systemFeatures.webapp_auth.enabled ? appAccessMode?.accessMode : AccessMode.PUBLIC,
|
||||
appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
|
||||
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
|
||||
isInstalledApp,
|
||||
allowResetChat,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
'use client'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useState,
|
||||
} from 'react'
|
||||
|
|
@ -12,7 +14,7 @@ import { useEmbeddedChatbot } from './hooks'
|
|||
import { isDify } from './utils'
|
||||
import { useThemeContext } from './theme/theme-context'
|
||||
import { CssTransform } from './theme/utils'
|
||||
import { checkOrSetAccessToken } from '@/app/components/share/utils'
|
||||
import { checkOrSetAccessToken, removeAccessToken } from '@/app/components/share/utils'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
|
|
@ -23,6 +25,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
|
|||
import cn from '@/utils/classnames'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
|
||||
const Chatbot = () => {
|
||||
const {
|
||||
|
|
@ -36,6 +39,7 @@ const Chatbot = () => {
|
|||
chatShouldReloadKey,
|
||||
handleNewConversation,
|
||||
themeBuilder,
|
||||
isInstalledApp,
|
||||
} = useEmbeddedChatbotContext()
|
||||
const { t } = useTranslation()
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
|
@ -51,6 +55,22 @@ const Chatbot = () => {
|
|||
|
||||
useDocumentTitle(site?.title || 'Chat')
|
||||
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
params.set('redirect_url', pathname)
|
||||
return `/webapp-signin?${params.toString()}`
|
||||
}, [searchParams, pathname])
|
||||
|
||||
const backToHome = useCallback(() => {
|
||||
removeAccessToken()
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router])
|
||||
|
||||
if (appInfoLoading) {
|
||||
return (
|
||||
<>
|
||||
|
|
@ -66,8 +86,12 @@ const Chatbot = () => {
|
|||
)
|
||||
}
|
||||
|
||||
if (!userCanAccess)
|
||||
return <AppUnavailable code={403} unknownReason='no permission.' />
|
||||
if (!userCanAccess) {
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-2'>
|
||||
<AppUnavailable className='h-auto w-auto' code={403} unknownReason='no permission.' />
|
||||
{!isInstalledApp && <span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{t('common.userProfile.logout')}</span>}
|
||||
</div>
|
||||
}
|
||||
|
||||
if (appInfoError) {
|
||||
return (
|
||||
|
|
@ -141,7 +165,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
appInfoError,
|
||||
appInfoLoading,
|
||||
appData,
|
||||
accessMode,
|
||||
userCanAccess,
|
||||
appParams,
|
||||
appMeta,
|
||||
|
|
@ -176,7 +199,6 @@ const EmbeddedChatbotWrapper = () => {
|
|||
|
||||
return <EmbeddedChatbotContext.Provider value={{
|
||||
userCanAccess,
|
||||
accessMode,
|
||||
appInfoError,
|
||||
appInfoLoading,
|
||||
appData,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,17 @@ async function decodeBase64AndDecompress(base64String: string) {
|
|||
}
|
||||
}
|
||||
|
||||
async function getRawInputsFromUrlParams(): Promise<Record<string, any>> {
|
||||
const urlParams = new URLSearchParams(window.location.search)
|
||||
const inputs: Record<string, any> = {}
|
||||
const entriesArray = Array.from(urlParams.entries())
|
||||
entriesArray.forEach(([key, value]) => {
|
||||
if (!key.startsWith('sys.'))
|
||||
inputs[key] = decodeURIComponent(value)
|
||||
})
|
||||
return inputs
|
||||
}
|
||||
|
||||
async function getProcessedInputsFromUrlParams(): Promise<Record<string, any>> {
|
||||
const urlParams = new URLSearchParams(window.location.search)
|
||||
const inputs: Record<string, any> = {}
|
||||
|
|
@ -184,6 +195,7 @@ function getThreadMessages(tree: ChatItemInTree[], targetMessageId?: string): Ch
|
|||
}
|
||||
|
||||
export {
|
||||
getRawInputsFromUrlParams,
|
||||
getProcessedInputsFromUrlParams,
|
||||
getProcessedSystemVariablesFromUrlParams,
|
||||
isValidGeneratedAnswer,
|
||||
|
|
|
|||
|
|
@ -231,7 +231,7 @@ export const useFile = (fileConfig: FileUpload) => {
|
|||
url: res.url,
|
||||
}
|
||||
if (!isAllowedFileExtension(res.name, res.mime_type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) {
|
||||
notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') })
|
||||
notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${file.type}` })
|
||||
handleRemoveFile(uploadingFile.id)
|
||||
}
|
||||
if (!checkSizeLimit(newFile.supportFileType, newFile.size))
|
||||
|
|
@ -257,7 +257,7 @@ export const useFile = (fileConfig: FileUpload) => {
|
|||
|
||||
const handleLocalFileUpload = useCallback((file: File) => {
|
||||
if (!isAllowedFileExtension(file.name, file.type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) {
|
||||
notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') })
|
||||
notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${file.type}` })
|
||||
return
|
||||
}
|
||||
const allowedFileTypes = fileConfig.allowed_file_types
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import { FILE_EXTS } from '../prompt-editor/constants'
|
|||
jest.mock('mime', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getExtension: jest.fn(),
|
||||
getAllExtensions: jest.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
|
|
@ -58,12 +58,27 @@ describe('file-uploader utils', () => {
|
|||
|
||||
describe('getFileExtension', () => {
|
||||
it('should get extension from mimetype', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('pdf')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf']))
|
||||
expect(getFileExtension('file', 'application/pdf')).toBe('pdf')
|
||||
})
|
||||
|
||||
it('should get extension from mimetype and file name 1', () => {
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf']))
|
||||
expect(getFileExtension('file.pdf', 'application/pdf')).toBe('pdf')
|
||||
})
|
||||
|
||||
it('should get extension from mimetype with multiple ext candidates with filename hint', () => {
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['der', 'crt', 'pem']))
|
||||
expect(getFileExtension('file.pem', 'application/x-x509-ca-cert')).toBe('pem')
|
||||
})
|
||||
|
||||
it('should get extension from mimetype with multiple ext candidates without filename hint', () => {
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['der', 'crt', 'pem']))
|
||||
expect(getFileExtension('file', 'application/x-x509-ca-cert')).toBe('der')
|
||||
})
|
||||
|
||||
it('should get extension from filename if mimetype fails', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue(null)
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(null)
|
||||
expect(getFileExtension('file.txt', '')).toBe('txt')
|
||||
expect(getFileExtension('file.txt.docx', '')).toBe('docx')
|
||||
expect(getFileExtension('file', '')).toBe('')
|
||||
|
|
@ -76,157 +91,157 @@ describe('file-uploader utils', () => {
|
|||
|
||||
describe('getFileAppearanceType', () => {
|
||||
it('should identify gif files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('gif')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['gif']))
|
||||
expect(getFileAppearanceType('image.gif', 'image/gif'))
|
||||
.toBe(FileAppearanceTypeEnum.gif)
|
||||
})
|
||||
|
||||
it('should identify image files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('jpg')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['jpg']))
|
||||
expect(getFileAppearanceType('image.jpg', 'image/jpeg'))
|
||||
.toBe(FileAppearanceTypeEnum.image)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('jpeg')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['jpeg']))
|
||||
expect(getFileAppearanceType('image.jpeg', 'image/jpeg'))
|
||||
.toBe(FileAppearanceTypeEnum.image)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('png')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['png']))
|
||||
expect(getFileAppearanceType('image.png', 'image/png'))
|
||||
.toBe(FileAppearanceTypeEnum.image)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('webp')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['webp']))
|
||||
expect(getFileAppearanceType('image.webp', 'image/webp'))
|
||||
.toBe(FileAppearanceTypeEnum.image)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('svg')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['svg']))
|
||||
expect(getFileAppearanceType('image.svg', 'image/svgxml'))
|
||||
.toBe(FileAppearanceTypeEnum.image)
|
||||
})
|
||||
|
||||
it('should identify video files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mp4')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mp4']))
|
||||
expect(getFileAppearanceType('video.mp4', 'video/mp4'))
|
||||
.toBe(FileAppearanceTypeEnum.video)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mov')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mov']))
|
||||
expect(getFileAppearanceType('video.mov', 'video/quicktime'))
|
||||
.toBe(FileAppearanceTypeEnum.video)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mpeg')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mpeg']))
|
||||
expect(getFileAppearanceType('video.mpeg', 'video/mpeg'))
|
||||
.toBe(FileAppearanceTypeEnum.video)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('webm')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['webm']))
|
||||
expect(getFileAppearanceType('video.web', 'video/webm'))
|
||||
.toBe(FileAppearanceTypeEnum.video)
|
||||
})
|
||||
|
||||
it('should identify audio files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mp3')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mp3']))
|
||||
expect(getFileAppearanceType('audio.mp3', 'audio/mpeg'))
|
||||
.toBe(FileAppearanceTypeEnum.audio)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('m4a')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['m4a']))
|
||||
expect(getFileAppearanceType('audio.m4a', 'audio/mp4'))
|
||||
.toBe(FileAppearanceTypeEnum.audio)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('wav')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['wav']))
|
||||
expect(getFileAppearanceType('audio.wav', 'audio/vnd.wav'))
|
||||
.toBe(FileAppearanceTypeEnum.audio)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('amr')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['amr']))
|
||||
expect(getFileAppearanceType('audio.amr', 'audio/AMR'))
|
||||
.toBe(FileAppearanceTypeEnum.audio)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mpga')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mpga']))
|
||||
expect(getFileAppearanceType('audio.mpga', 'audio/mpeg'))
|
||||
.toBe(FileAppearanceTypeEnum.audio)
|
||||
})
|
||||
|
||||
it('should identify code files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('html')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['html']))
|
||||
expect(getFileAppearanceType('index.html', 'text/html'))
|
||||
.toBe(FileAppearanceTypeEnum.code)
|
||||
})
|
||||
|
||||
it('should identify PDF files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('pdf')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf']))
|
||||
expect(getFileAppearanceType('doc.pdf', 'application/pdf'))
|
||||
.toBe(FileAppearanceTypeEnum.pdf)
|
||||
})
|
||||
|
||||
it('should identify markdown files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('md')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['md']))
|
||||
expect(getFileAppearanceType('file.md', 'text/markdown'))
|
||||
.toBe(FileAppearanceTypeEnum.markdown)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('markdown')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['markdown']))
|
||||
expect(getFileAppearanceType('file.markdown', 'text/markdown'))
|
||||
.toBe(FileAppearanceTypeEnum.markdown)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('mdx')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mdx']))
|
||||
expect(getFileAppearanceType('file.mdx', 'text/mdx'))
|
||||
.toBe(FileAppearanceTypeEnum.markdown)
|
||||
})
|
||||
|
||||
it('should identify excel files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('xlsx')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xlsx']))
|
||||
expect(getFileAppearanceType('doc.xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'))
|
||||
.toBe(FileAppearanceTypeEnum.excel)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('xls')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xls']))
|
||||
expect(getFileAppearanceType('doc.xls', 'application/vnd.ms-excel'))
|
||||
.toBe(FileAppearanceTypeEnum.excel)
|
||||
})
|
||||
|
||||
it('should identify word files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('doc')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['doc']))
|
||||
expect(getFileAppearanceType('doc.doc', 'application/msword'))
|
||||
.toBe(FileAppearanceTypeEnum.word)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('docx')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['docx']))
|
||||
expect(getFileAppearanceType('doc.docx', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'))
|
||||
.toBe(FileAppearanceTypeEnum.word)
|
||||
})
|
||||
|
||||
it('should identify word files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('ppt')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['ppt']))
|
||||
expect(getFileAppearanceType('doc.ppt', 'application/vnd.ms-powerpoint'))
|
||||
.toBe(FileAppearanceTypeEnum.ppt)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('pptx')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pptx']))
|
||||
expect(getFileAppearanceType('doc.pptx', 'application/vnd.openxmlformats-officedocument.presentationml.presentation'))
|
||||
.toBe(FileAppearanceTypeEnum.ppt)
|
||||
})
|
||||
|
||||
it('should identify document files', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('txt')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['txt']))
|
||||
expect(getFileAppearanceType('file.txt', 'text/plain'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('csv')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['csv']))
|
||||
expect(getFileAppearanceType('file.csv', 'text/csv'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('msg')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['msg']))
|
||||
expect(getFileAppearanceType('file.msg', 'application/vnd.ms-outlook'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('eml')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['eml']))
|
||||
expect(getFileAppearanceType('file.eml', 'message/rfc822'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('xml')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xml']))
|
||||
expect(getFileAppearanceType('file.xml', 'application/rssxml'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
|
||||
jest.mocked(mime.getExtension).mockReturnValue('epub')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['epub']))
|
||||
expect(getFileAppearanceType('file.epub', 'application/epubzip'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
})
|
||||
|
||||
it('should handle null mime extension', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue(null)
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(null)
|
||||
expect(getFileAppearanceType('file.txt', 'text/plain'))
|
||||
.toBe(FileAppearanceTypeEnum.document)
|
||||
})
|
||||
|
|
@ -360,7 +375,7 @@ describe('file-uploader utils', () => {
|
|||
|
||||
describe('isAllowedFileExtension', () => {
|
||||
it('should validate allowed file extensions', () => {
|
||||
jest.mocked(mime.getExtension).mockReturnValue('pdf')
|
||||
jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf']))
|
||||
expect(isAllowedFileExtension(
|
||||
'test.pdf',
|
||||
'application/pdf',
|
||||
|
|
|
|||
|
|
@ -42,19 +42,38 @@ export const fileUpload: FileUpload = ({
|
|||
})
|
||||
}
|
||||
|
||||
const additionalExtensionMap = new Map<string, string[]>([
|
||||
['text/x-markdown', ['md']],
|
||||
])
|
||||
|
||||
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
||||
let extension = ''
|
||||
if (fileMimetype)
|
||||
extension = mime.getExtension(fileMimetype) || ''
|
||||
let extensions = new Set<string>()
|
||||
if (fileMimetype) {
|
||||
const extensionsFromMimeType = mime.getAllExtensions(fileMimetype) || new Set<string>()
|
||||
const additionalExtensions = additionalExtensionMap.get(fileMimetype) || []
|
||||
extensions = new Set<string>([
|
||||
...extensionsFromMimeType,
|
||||
...additionalExtensions,
|
||||
])
|
||||
}
|
||||
|
||||
if (fileName && !extension) {
|
||||
let extensionInFileName = ''
|
||||
if (fileName) {
|
||||
const fileNamePair = fileName.split('.')
|
||||
const fileNamePairLength = fileNamePair.length
|
||||
|
||||
if (fileNamePairLength > 1)
|
||||
extension = fileNamePair[fileNamePairLength - 1]
|
||||
if (fileNamePairLength > 1) {
|
||||
extensionInFileName = fileNamePair[fileNamePairLength - 1].toLowerCase()
|
||||
if (extensions.has(extensionInFileName))
|
||||
extension = extensionInFileName
|
||||
}
|
||||
}
|
||||
if (!extension) {
|
||||
if (extensions.size > 0)
|
||||
extension = extensions.values().next().value.toLowerCase()
|
||||
else
|
||||
extension = ''
|
||||
extension = extensionInFileName
|
||||
}
|
||||
|
||||
if (isRemote)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { useChatContext } from '@/app/components/base/chat/chat/context'
|
||||
import Button from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
import { isValidUrl } from './utils'
|
||||
const MarkdownButton = ({ node }: any) => {
|
||||
const { onSend } = useChatContext()
|
||||
const variant = node.properties.dataVariant
|
||||
|
|
@ -9,25 +9,17 @@ const MarkdownButton = ({ node }: any) => {
|
|||
const link = node.properties.dataLink
|
||||
const size = node.properties.dataSize
|
||||
|
||||
function is_valid_url(url: string): boolean {
|
||||
try {
|
||||
const parsed_url = new URL(url)
|
||||
return ['http:', 'https:'].includes(parsed_url.protocol)
|
||||
}
|
||||
catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return <Button
|
||||
variant={variant}
|
||||
size={size}
|
||||
className={cn('!h-auto min-h-8 select-none whitespace-normal !px-3')}
|
||||
onClick={() => {
|
||||
if (is_valid_url(link)) {
|
||||
if (isValidUrl(link)) {
|
||||
window.open(link, '_blank')
|
||||
return
|
||||
}
|
||||
if(!message)
|
||||
return
|
||||
onSend?.(message)
|
||||
}}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
*/
|
||||
import React from 'react'
|
||||
import { useChatContext } from '@/app/components/base/chat/chat/context'
|
||||
import { isValidUrl } from './utils'
|
||||
|
||||
const Link = ({ node, children, ...props }: any) => {
|
||||
const { onSend } = useChatContext()
|
||||
|
|
@ -14,7 +15,11 @@ const Link = ({ node, children, ...props }: any) => {
|
|||
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr>
|
||||
}
|
||||
else {
|
||||
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
|
||||
const href = props.href || node.properties?.href
|
||||
if(!isValidUrl(href))
|
||||
return <span>{children}</span>
|
||||
|
||||
return <a href={href} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
export const isValidUrl = (url: string): boolean => {
|
||||
return ['http:', 'https:', '//', 'mailto:'].some(prefix => url.startsWith(prefix))
|
||||
}
|
||||
|
|
@ -7,7 +7,7 @@ import RemarkGfm from 'remark-gfm'
|
|||
import RehypeRaw from 'rehype-raw'
|
||||
import { flow } from 'lodash-es'
|
||||
import cn from '@/utils/classnames'
|
||||
import { preprocessLaTeX, preprocessThinkTag } from './markdown-utils'
|
||||
import { customUrlTransform, preprocessLaTeX, preprocessThinkTag } from './markdown-utils'
|
||||
import {
|
||||
AudioBlock,
|
||||
CodeBlock,
|
||||
|
|
@ -65,6 +65,7 @@ export function Markdown(props: { content: string; className?: string; customDis
|
|||
}
|
||||
},
|
||||
]}
|
||||
urlTransform={customUrlTransform}
|
||||
disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]}
|
||||
components={{
|
||||
code: CodeBlock,
|
||||
|
|
|
|||
|
|
@ -36,3 +36,52 @@ export const preprocessThinkTag = (content: string) => {
|
|||
(str: string) => str.replace(/(<\/details>)(?![^\S\r\n]*[\r\n])(?![^\S\r\n]*$)/g, '$1\n'),
|
||||
])(content)
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms a URI for use in react-markdown, ensuring security and compatibility.
|
||||
* This function is designed to work with react-markdown v9+ which has stricter
|
||||
* default URL handling.
|
||||
*
|
||||
* Behavior:
|
||||
* 1. Always allows the custom 'abbr:' protocol.
|
||||
* 2. Always allows page-local fragments (e.g., "#some-id").
|
||||
* 3. Always allows protocol-relative URLs (e.g., "//example.com/path").
|
||||
* 4. Always allows purely relative paths (e.g., "path/to/file", "/abs/path").
|
||||
* 5. Allows absolute URLs if their scheme is in a permitted list (case-insensitive):
|
||||
* 'http:', 'https:', 'mailto:', 'xmpp:', 'irc:', 'ircs:'.
|
||||
* 6. Intelligently distinguishes colons used for schemes from colons within
|
||||
* paths, query parameters, or fragments of relative-like URLs.
|
||||
* 7. Returns the original URI if allowed, otherwise returns `undefined` to
|
||||
* signal that the URI should be removed/disallowed by react-markdown.
|
||||
*/
|
||||
export const customUrlTransform = (uri: string): string | undefined => {
|
||||
const PERMITTED_SCHEME_REGEX = /^(https?|ircs?|mailto|xmpp|abbr):$/i
|
||||
|
||||
if (uri.startsWith('#'))
|
||||
return uri
|
||||
|
||||
if (uri.startsWith('//'))
|
||||
return uri
|
||||
|
||||
const colonIndex = uri.indexOf(':')
|
||||
|
||||
if (colonIndex === -1)
|
||||
return uri
|
||||
|
||||
const slashIndex = uri.indexOf('/')
|
||||
const questionMarkIndex = uri.indexOf('?')
|
||||
const hashIndex = uri.indexOf('#')
|
||||
|
||||
if (
|
||||
(slashIndex !== -1 && colonIndex > slashIndex)
|
||||
|| (questionMarkIndex !== -1 && colonIndex > questionMarkIndex)
|
||||
|| (hashIndex !== -1 && colonIndex > hashIndex)
|
||||
)
|
||||
return uri
|
||||
|
||||
const scheme = uri.substring(0, colonIndex + 1).toLowerCase()
|
||||
if (PERMITTED_SCHEME_REGEX.test(scheme))
|
||||
return uri
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
|
|
|||
|
|
@ -487,15 +487,15 @@ const Flowchart = React.forwardRef((props: {
|
|||
'bg-white': currentTheme === Theme.light,
|
||||
'bg-slate-900': currentTheme === Theme.dark,
|
||||
}),
|
||||
mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', {
|
||||
mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', {
|
||||
'bg-white': currentTheme === Theme.light,
|
||||
'bg-slate-900': currentTheme === Theme.dark,
|
||||
}),
|
||||
errorMessage: cn('py-4 px-[26px]', {
|
||||
errorMessage: cn('px-[26px] py-4', {
|
||||
'text-red-500': currentTheme === Theme.light,
|
||||
'text-red-400': currentTheme === Theme.dark,
|
||||
}),
|
||||
errorIcon: cn('w-6 h-6', {
|
||||
errorIcon: cn('h-6 w-6', {
|
||||
'text-red-500': currentTheme === Theme.light,
|
||||
'text-red-400': currentTheme === Theme.dark,
|
||||
}),
|
||||
|
|
@ -503,7 +503,7 @@ const Flowchart = React.forwardRef((props: {
|
|||
'text-gray-700': currentTheme === Theme.light,
|
||||
'text-gray-300': currentTheme === Theme.dark,
|
||||
}),
|
||||
themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', {
|
||||
themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
|
||||
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
|
||||
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
|
||||
}),
|
||||
|
|
@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {
|
|||
// Style classes for look options
|
||||
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
|
||||
return cn(
|
||||
'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary',
|
||||
'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
|
||||
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
|
||||
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
|
||||
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
|
||||
|
|
@ -523,7 +523,7 @@ const Flowchart = React.forwardRef((props: {
|
|||
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
|
||||
<div className={themeClasses.segmented}>
|
||||
<div className="msh-segmented-group">
|
||||
<label className="msh-segmented-item flex items-center space-x-1 m-2 w-[200px]">
|
||||
<label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1">
|
||||
<div
|
||||
key='classic'
|
||||
className={getLookButtonClass('classic')}
|
||||
|
|
@ -545,7 +545,7 @@ const Flowchart = React.forwardRef((props: {
|
|||
<div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} />
|
||||
|
||||
{isLoading && !svgCode && (
|
||||
<div className='py-4 px-[26px]'>
|
||||
<div className='px-[26px] py-4'>
|
||||
<LoadingAnim type='text'/>
|
||||
{!isCodeComplete && (
|
||||
<div className="mt-2 text-sm text-gray-500">
|
||||
|
|
@ -557,7 +557,7 @@ const Flowchart = React.forwardRef((props: {
|
|||
|
||||
{svgCode && (
|
||||
<div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}>
|
||||
<div className="absolute left-2 bottom-2 z-[100]">
|
||||
<div className="absolute bottom-2 left-2 z-[100]">
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue