Merge branch 'main' into feat/grouping-branching

This commit is contained in:
zhsama 2026-01-02 01:34:02 +08:00
commit bd338a9043
278 changed files with 3019 additions and 1015 deletions

View File

@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
// No explicit mock needed for most tests
//
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// import { createReactI18nextMock } from '@/test/i18n-mock'
// vi.mock('react-i18next', () => createReactI18nextMock({
// 'my.custom.key': 'Custom Translation',
// 'button.save': 'Save',
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)

View File

@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
The global mock provides:
- `useTranslation` - returns translation keys with namespace prefix
- `Trans` component - renders i18nKey and components
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
**Default behavior**: Most tests should use the global mock (no local override needed).
**For custom translations**: Use the helper function from `@/test/i18n-mock`:
```typescript
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
import { createReactI18nextMock } from '@/test/i18n-mock'
vi.mock('react-i18next', () => createReactI18nextMock({
'my.custom.key': 'Custom translation',
'button.save': 'Save',
}))
```
**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
### 2. Next.js Router
```typescript

View File

@ -110,6 +110,16 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run build
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key

View File

@ -1,4 +1,8 @@
exclude = ["migrations/*"]
exclude = [
"migrations/*",
".git",
".git/**",
]
line-length = 120
[format]

View File

@ -1,9 +1,11 @@
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
ArchiveStorageConfig,
NotionConfig,
SentryConfig,
):

View File

@ -0,0 +1,43 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class ArchiveStorageConfig(BaseSettings):
"""
Configuration settings for workflow run logs archiving storage.
"""
ARCHIVE_STORAGE_ENABLED: bool = Field(
description="Enable workflow run logs archiving to S3-compatible storage",
default=False,
)
ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
default=None,
)
ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
description="Name of the bucket to store archived workflow logs",
default=None,
)
ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
description="Name of the bucket to store exported workflow runs",
default=None,
)
ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
description="Access key ID for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
description="Secret access key for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_REGION: str = Field(
description="Region for storage (use 'auto' if the provider supports it)",
default="auto",
)

View File

@ -1,62 +1,59 @@
from flask_restx import Api, Namespace, fields
from __future__ import annotations
from libs.helper import AppIconUrlField
from typing import Any, TypeAlias
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
from pydantic import BaseModel, ConfigDict, computed_field
from core.file import helpers as file_helpers
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
def build_system_parameters_model(api_or_ns: Api | Namespace):
"""Build the system parameters model for the API or Namespace."""
return api_or_ns.model("SystemParameters", parameters__system_parameters)
class SystemParameters(BaseModel):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
class Parameters(BaseModel):
opening_statement: str | None = None
suggested_questions: list[str]
suggested_questions_after_answer: JSONObject
speech_to_text: JSONObject
text_to_speech: JSONObject
retriever_resource: JSONObject
annotation_reply: JSONObject
more_like_this: JSONObject
user_input_form: list[JSONObject]
sensitive_word_avoidance: JSONObject
file_upload: JSONObject
system_parameters: SystemParameters
def build_parameters_model(api_or_ns: Api | Namespace):
"""Build the parameters model for the API or Namespace."""
copied_fields = parameters_fields.copy()
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
return api_or_ns.model("Parameters", copied_fields)
class Site(BaseModel):
model_config = ConfigDict(from_attributes=True)
title: str
chat_color_theme: str | None = None
chat_color_theme_inverted: bool
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
default_language: str
show_workflow_steps: bool
use_icon_as_answer_icon: bool
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
def build_site_model(api_or_ns: Api | Namespace):
"""Build the site model for the API or Namespace."""
return api_or_ns.model("Site", site_fields)
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
if self.icon and self.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(self.icon)
return None

View File

@ -1,3 +1,4 @@
import re
import uuid
from typing import Literal
@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
r"<script[^>]*>.*?</script>", # Script tags
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
r"javascript:", # JavaScript protocol
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
r"<embed[^>]*>", # Embed tags (self-closing)
r"<link[^>]*>", # Link tags with javascript
]
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
"""
Validate that a string value doesn't contain potential XSS payloads.
Args:
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The original value if safe
Raises:
ValueError: If the value contains XSS patterns
"""
if value is None:
return None
value_lower = value.lower()
for pattern in _XSS_PATTERNS:
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
raise ValueError(
f"{field_name} contains invalid characters or patterns. "
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
)
return value
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")

View File

@ -124,7 +124,7 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info)
account, oauth_new_user = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@ -159,7 +159,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
return account
def _generate_account(provider: str, user_info: OAuthUserInfo):
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant)
if not account:
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
return account
return account, oauth_new_user

View File

@ -3,10 +3,12 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
.correlate(DocumentSegment)
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
keywords_condition,
)
)
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":

View File

@ -1,5 +1,3 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,7 +1,7 @@
from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
@ -92,7 +92,7 @@ annotation_list_fields = {
}
def build_annotation_list_model(api_or_ns: Api | Namespace):
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from controllers.common.fields import build_parameters_model
from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app parameters.
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@service_api_ns.route("/meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from controllers.common.fields import build_site_model
from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app site info.
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
return SiteResponse.model_validate(site).model_dump(mode="json")

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -78,7 +78,7 @@ workflow_run_fields = {
}
def build_workflow_run_model(api_or_ns: Api | Namespace):
def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)

View File

@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@web_ns.route("/meta")

View File

@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:

View File

@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + "\n"
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and tool_calls:
raise AgentMaxIterationError(app_config.agent.max_iteration)
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:

View File

@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
# Remove URL but keep Markdown image URLs and link URLs
# Replace the ENTIRE markdown link/image with a single placeholder to protect
# the link text (which might also be a URL) from being removed
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
def replace_with_placeholder(match, placeholders=placeholders):
def replace_markdown_with_placeholder(match, placeholders=placeholders):
link_type = "link"
link_text = match.group(1)
url = match.group(2)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, link_text, url))
return placeholder
def replace_image_with_placeholder(match, placeholders=placeholders):
link_type = "image"
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, "image", url))
return placeholder
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Protect markdown links first
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
# Then protect markdown images
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
# Restore the Markdown links and images
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
if link_type == "link":
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
else: # image
text = text.replace(placeholder, f"![{text_or_alt}]({url})")
return text
def filter_string(self, text):

View File

@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:

View File

@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""
import contextlib
import io
import logging
import uuid
from collections.abc import Iterator
import pypdfium2
import pypdfium2.raw as pdfium_c
from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
"""Load pdf files.
"""
PdfExtractor is used to extract text and images from PDF files.
Args:
file_path: Path to the file to load.
file_path: Path to the PDF file.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
file_cache_key: Optional cache key for the extracted text.
"""
def __init__(self, file_path: str, file_cache_key: str | None = None):
"""Initialize with file path."""
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
IMAGE_FORMATS = [
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
(b"GIF8", "gif", "image/gif"),
(b"BM", "bmp", "image/bmp"),
(b"II*\x00", "tiff", "image/tiff"),
(b"MM\x00*", "tiff", "image/tiff"),
(b"II+\x00", "tiff", "image/tiff"),
(b"MM\x00+", "tiff", "image/tiff"),
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
image_content = self._extract_images(page)
if image_content:
content += "\n" + image_content
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
def _extract_images(self, page) -> str:
"""
Extract images from a PDF page, save them to storage and database,
and return markdown image links.
Args:
page: pypdfium2 page object.
Returns:
Markdown string containing links to the extracted images.
"""
image_content = []
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:
try:
# Extract image bytes
img_byte_arr = io.BytesIO()
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
# Fallback to png for other formats
obj.extract(img_byte_arr, fb_format="png")
img_bytes = img_byte_arr.getvalue()
if not img_bytes:
continue
header = img_bytes[: self.MAX_MAGIC_LEN]
image_ext = None
mime_type = None
for magic, ext, mime in self.IMAGE_FORMATS:
if header.startswith(magic):
image_ext = ext
mime_type = mime
break
if not image_ext or not mime_type:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
storage.save(file_key, img_bytes)
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=len(img_bytes),
extension=image_ext,
mime_type=mime_type,
created_by=self._user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self._user_id,
used_at=naive_utc_now(),
)
upload_files.append(upload_file)
image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)")
except Exception as e:
logger.warning("Failed to extract image from PDF: %s", e)
continue
except Exception as e:
logger.warning("Failed to get objects from PDF page: %s", e)
if upload_files:
db.session.add_all(upload_files)
db.session.commit()
return "\n".join(image_content)

View File

@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Check if text starts with a markdown link - preserve it
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
if re.match(markdown_link_pattern, text):
return text
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

View File

@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
if not dify_config.BROKER_USE_SSL:
return None
# Check if Celery is actually using Redis

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -12,7 +12,7 @@ annotation_fields = {
}
def build_annotation_model(api_or_ns: Api | Namespace):
def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -46,7 +46,7 @@ message_file_fields = {
}
def build_message_file_model(api_or_ns: Api | Namespace):
def build_message_file_model(api_or_ns: Namespace):
"""Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields)
@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
}
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns)
@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
def build_conversation_delete_model(api_or_ns: Api | Namespace):
def build_conversation_delete_model(api_or_ns: Namespace):
"""Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
def build_simple_conversation_model(api_or_ns: Api | Namespace):
def build_simple_conversation_model(api_or_ns: Namespace):
"""Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}
def build_conversation_variable_model(api_or_ns: Api | Namespace):
def build_conversation_variable_model(api_or_ns: Namespace):
"""Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ simple_end_user_fields = {
}
def build_simple_end_user_model(api_or_ns: Api | Namespace):
def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -14,7 +14,7 @@ upload_config_fields = {
}
def build_upload_config_model(api_or_ns: Api | Namespace):
def build_upload_config_model(api_or_ns: Namespace):
"""Build the upload config model for the API or Namespace.
Args:
@ -39,7 +39,7 @@ file_fields = {
}
def build_file_model(api_or_ns: Api | Namespace):
def build_file_model(api_or_ns: Namespace):
"""Build the file model for the API or Namespace.
Args:
@ -57,7 +57,7 @@ remote_file_info_fields = {
}
def build_remote_file_info_model(api_or_ns: Api | Namespace):
def build_remote_file_info_model(api_or_ns: Namespace):
"""Build the remote file info model for the API or Namespace.
Args:
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
}
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
def build_file_with_signed_url_model(api_or_ns: Namespace):
"""Build the file with signed URL model for the API or Namespace.
Args:

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
@ -9,7 +9,7 @@ simple_account_fields = {
}
def build_simple_account_model(api_or_ns: Api | Namespace):
def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
@ -10,7 +10,7 @@ feedback_fields = {
}
def build_feedback_model(api_or_ns: Api | Namespace):
def build_feedback_model(api_or_ns: Namespace):
"""Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields)
@ -30,7 +30,7 @@ agent_thought_fields = {
}
def build_agent_thought_model(api_or_ns: Api | Namespace):
def build_agent_thought_model(api_or_ns: Namespace):
"""Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
dataset_tag_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ dataset_tag_fields = {
}
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
}
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
}
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
"""Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
}
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)

347
api/libs/archive_storage.py Normal file
View File

@ -0,0 +1,347 @@
"""
Archive Storage Client for S3-compatible storage.
This module provides a dedicated storage client for archiving or exporting logs
to S3-compatible object storage.
"""
import base64
import datetime
import gzip
import hashlib
import logging
from collections.abc import Generator
from typing import Any, cast
import boto3
import orjson
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
logger = logging.getLogger(__name__)
class ArchiveStorageError(Exception):
"""Base exception for archive storage operations."""
pass
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
"""Raised when archive storage is not properly configured."""
pass
class ArchiveStorage:
"""
S3-compatible storage client for archiving or exporting.
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
"""
def __init__(self, bucket: str):
if not dify_config.ARCHIVE_STORAGE_ENABLED:
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
if not bucket:
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
if not all(
[
dify_config.ARCHIVE_STORAGE_ENDPOINT,
bucket,
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
]
):
raise ArchiveStorageNotConfiguredError(
"Archive storage configuration is incomplete. "
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
)
self.bucket = bucket
self.client = boto3.client(
"s3",
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
region_name=dify_config.ARCHIVE_STORAGE_REGION,
config=Config(s3={"addressing_style": "path"}),
)
# Verify bucket accessibility
try:
self.client.head_bucket(Bucket=self.bucket)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "404":
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
elif error_code == "403":
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
else:
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
def put_object(self, key: str, data: bytes) -> str:
"""
Upload an object to the archive storage.
Args:
key: Object key (path) within the bucket
data: Binary data to upload
Returns:
MD5 checksum of the uploaded data
Raises:
ArchiveStorageError: If upload fails
"""
checksum = hashlib.md5(data).hexdigest()
try:
self.client.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentMD5=self._content_md5(data),
)
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
return checksum
except ClientError as e:
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
def get_object(self, key: str) -> bytes:
"""
Download an object from the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
Binary data of the object
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
"""
Stream an object from the archive storage.
Args:
key: Object key (path) within the bucket
Yields:
Chunks of binary data
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
yield from response["Body"].iter_chunks()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
def object_exists(self, key: str) -> bool:
"""
Check if an object exists in the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
True if object exists, False otherwise
"""
try:
self.client.head_object(Bucket=self.bucket, Key=key)
return True
except ClientError:
return False
def delete_object(self, key: str) -> None:
"""
Delete an object from the archive storage.
Args:
key: Object key (path) within the bucket
Raises:
ArchiveStorageError: If deletion fails
"""
try:
self.client.delete_object(Bucket=self.bucket, Key=key)
logger.debug("Deleted object: %s", key)
except ClientError as e:
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
"""
Generate a pre-signed URL for downloading an object.
Args:
key: Object key (path) within the bucket
expires_in: URL validity duration in seconds (default: 1 hour)
Returns:
Pre-signed URL string.
Raises:
ArchiveStorageError: If generation fails
"""
try:
return self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket, "Key": key},
ExpiresIn=expires_in,
)
except ClientError as e:
raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
def list_objects(self, prefix: str) -> list[str]:
"""
List objects under a given prefix.
Args:
prefix: Object key prefix to filter by
Returns:
List of object keys matching the prefix
"""
keys = []
paginator = self.client.get_paginator("list_objects_v2")
try:
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
for obj in page.get("Contents", []):
keys.append(obj["Key"])
except ClientError as e:
raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
return keys
@staticmethod
def _content_md5(data: bytes) -> str:
"""Calculate base64-encoded MD5 for Content-MD5 header."""
return base64.b64encode(hashlib.md5(data).digest()).decode()
@staticmethod
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
"""
Serialize records to gzipped JSONL format.
Args:
records: List of dictionaries to serialize
Returns:
Gzipped JSONL bytes
"""
lines = []
for record in records:
# Convert datetime objects to ISO format strings
serialized = ArchiveStorage._serialize_record(record)
lines.append(orjson.dumps(serialized))
jsonl_content = b"\n".join(lines)
if jsonl_content:
jsonl_content += b"\n"
return gzip.compress(jsonl_content)
@staticmethod
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
"""
Deserialize gzipped JSONL data to records.
Args:
data: Gzipped JSONL bytes
Returns:
List of dictionaries
"""
jsonl_content = gzip.decompress(data)
records = []
for line in jsonl_content.splitlines():
if line:
records.append(orjson.loads(line))
return records
@staticmethod
def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
"""Serialize a single record, converting special types."""
def _serialize(item: Any) -> Any:
if isinstance(item, datetime.datetime):
return item.isoformat()
if isinstance(item, dict):
return {key: _serialize(value) for key, value in item.items()}
if isinstance(item, list):
return [_serialize(value) for value in item]
return item
return cast(dict[str, Any], _serialize(record))
@staticmethod
def compute_checksum(data: bytes) -> str:
"""Compute MD5 checksum of data."""
return hashlib.md5(data).hexdigest()
# Singleton instance (lazy initialization)
_archive_storage: ArchiveStorage | None = None
_export_storage: ArchiveStorage | None = None
def get_archive_storage() -> ArchiveStorage:
"""
Get the archive storage singleton instance.
Returns:
ArchiveStorage instance
Raises:
ArchiveStorageNotConfiguredError: If archive storage is not configured
"""
global _archive_storage
if _archive_storage is None:
archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
if not archive_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
)
_archive_storage = ArchiveStorage(bucket=archive_bucket)
return _archive_storage
def get_export_storage() -> ArchiveStorage:
"""
Get the export storage singleton instance.
Returns:
ArchiveStorage instance
"""
global _export_storage
if _export_storage is None:
export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
if not export_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
)
_export_storage = ArchiveStorage(bucket=export_bucket)
return _export_storage

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated
from .base import TypeBase
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@validates("status")
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
if isinstance(value, AccountStatus):
return value.value
return value
@property
def is_password_set(self):
return self.password is not None

View File

@ -16,6 +16,11 @@ celery_redis = Redis(
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
ssl=bool(dify_config.BROKER_USE_SSL),
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
)
logger = logging.getLogger(__name__)

View File

@ -85,7 +85,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -103,7 +105,7 @@ class ApiToolManageService:
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
@ -112,9 +114,6 @@ class ApiToolManageService:
"""
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -241,18 +240,15 @@ class ApiToolManageService:
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -277,7 +273,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@ -356,7 +352,7 @@ class ApiToolManageService:
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
):
"""

View File

@ -0,0 +1,69 @@
import builtins
from types import SimpleNamespace
from unittest.mock import patch
from flask.views import MethodView as FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
builtins.MethodView = FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = True
from controllers.common.fields import Parameters, Site
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import IconType
def test_parameters_model_round_trip():
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
model = Parameters.model_validate(parameters)
assert model.model_dump(mode="json") == parameters
def test_site_icon_url_uses_signed_url_for_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.IMAGE,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url == "signed"
mock_helper.assert_called_once_with("file-id")
def test_site_icon_url_is_none_for_non_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.EMOJI,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url is None
mock_helper.assert_not_called()

View File

@ -0,0 +1,254 @@
"""
Unit tests for XSS prevention in App payloads.
This test module validates that HTML tags, JavaScript, and other potentially
dangerous content are rejected in App names and descriptions.
"""
import pytest
from controllers.console.app.app import CopyAppPayload, CreateAppPayload, UpdateAppPayload
class TestXSSPreventionUnit:
"""Unit tests for XSS prevention in App payloads."""
def test_create_app_valid_names(self):
"""Test CreateAppPayload with valid app names."""
# Normal app names should be valid
valid_names = [
"My App",
"Test App 123",
"App with - dash",
"App with _ underscore",
"App with + plus",
"App with () parentheses",
"App with [] brackets",
"App with {} braces",
"App with ! exclamation",
"App with @ at",
"App with # hash",
"App with $ dollar",
"App with % percent",
"App with ^ caret",
"App with & ampersand",
"App with * asterisk",
"Unicode: 测试应用",
"Emoji: 🤖",
"Mixed: Test 测试 123",
]
for name in valid_names:
payload = CreateAppPayload(
name=name,
mode="chat",
)
assert payload.name == name
def test_create_app_xss_script_tags(self):
"""Test CreateAppPayload rejects script tags."""
xss_payloads = [
"<script>alert(document.cookie)</script>",
"<Script>alert(1)</Script>",
"<SCRIPT>alert('XSS')</SCRIPT>",
"<script>alert(String.fromCharCode(88,83,83))</script>",
"<script src='evil.js'></script>",
"<script>document.location='http://evil.com'</script>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_iframe_tags(self):
"""Test CreateAppPayload rejects iframe tags."""
xss_payloads = [
"<iframe src='evil.com'></iframe>",
"<Iframe srcdoc='<script>alert(1)</script>'></iframe>",
"<IFRAME src='javascript:alert(1)'></iframe>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_javascript_protocol(self):
"""Test CreateAppPayload rejects javascript: protocol."""
xss_payloads = [
"javascript:alert(1)",
"JAVASCRIPT:alert(1)",
"JavaScript:alert(document.cookie)",
"javascript:void(0)",
"javascript://comment%0Aalert(1)",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_svg_onload(self):
"""Test CreateAppPayload rejects SVG with onload."""
xss_payloads = [
"<svg onload=alert(1)>",
"<SVG ONLOAD=alert(1)>",
"<svg/x/onload=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_event_handlers(self):
"""Test CreateAppPayload rejects HTML event handlers."""
xss_payloads = [
"<div onclick=alert(1)>",
"<img onerror=alert(1)>",
"<body onload=alert(1)>",
"<input onfocus=alert(1)>",
"<a onmouseover=alert(1)>",
"<DIV ONCLICK=alert(1)>",
"<img src=x onerror=alert(1)>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_object_embed(self):
"""Test CreateAppPayload rejects object and embed tags."""
xss_payloads = [
"<object data='evil.swf'></object>",
"<embed src='evil.swf'>",
"<OBJECT data='javascript:alert(1)'></OBJECT>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_link_javascript(self):
"""Test CreateAppPayload rejects link tags with javascript."""
xss_payloads = [
"<link href='javascript:alert(1)'>",
"<LINK HREF='javascript:alert(1)'>",
]
for name in xss_payloads:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_xss_in_description(self):
"""Test CreateAppPayload rejects XSS in description."""
xss_descriptions = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for description in xss_descriptions:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(
name="Valid Name",
mode="chat",
description=description,
)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_create_app_valid_descriptions(self):
"""Test CreateAppPayload with valid descriptions."""
valid_descriptions = [
"A simple description",
"Description with < and > symbols",
"Description with & ampersand",
"Description with 'quotes' and \"double quotes\"",
"Description with / slashes",
"Description with \\ backslashes",
"Description with ; semicolons",
"Unicode: 这是一个描述",
"Emoji: 🎉🚀",
]
for description in valid_descriptions:
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=description,
)
assert payload.description == description
def test_create_app_none_description(self):
"""Test CreateAppPayload with None description."""
payload = CreateAppPayload(
name="Valid App Name",
mode="chat",
description=None,
)
assert payload.description is None
def test_update_app_xss_prevention(self):
"""Test UpdateAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
UpdateAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_update_app_valid_names(self):
"""Test UpdateAppPayload with valid names."""
payload = UpdateAppPayload(name="Valid Updated Name")
assert payload.name == "Valid Updated Name"
def test_copy_app_xss_prevention(self):
"""Test CopyAppPayload also prevents XSS."""
xss_names = [
"<script>alert(1)</script>",
"javascript:alert(1)",
"<img onerror=alert(1)>",
]
for name in xss_names:
with pytest.raises(ValueError) as exc_info:
CopyAppPayload(name=name)
assert "invalid characters or patterns" in str(exc_info.value).lower()
def test_copy_app_valid_names(self):
"""Test CopyAppPayload with valid names."""
payload = CopyAppPayload(name="Valid Copy Name")
assert payload.name == "Valid Copy Name"
def test_copy_app_none_name(self):
"""Test CopyAppPayload with None name (should be allowed)."""
payload = CopyAppPayload(name=None)
assert payload.name is None
def test_edge_case_angle_brackets_content(self):
"""Test that angle brackets with actual content are rejected."""
# Angle brackets without valid HTML-like patterns should be checked
# The regex pattern <.*?on\w+\s*= should catch event handlers
# But let's verify other patterns too
# Valid: angle brackets used as symbols (not matched by our patterns)
# Our patterns specifically look for dangerous constructs
# Invalid: actual HTML tags with event handlers
invalid_names = [
"<div onclick=xss>",
"<img src=x onerror=alert(1)>",
]
for name in invalid_names:
with pytest.raises(ValueError) as exc_info:
CreateAppPayload(name=name, mode="chat")
assert "invalid characters or patterns" in str(exc_info.value).lower()

View File

@ -171,7 +171,7 @@ class TestOAuthCallback:
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_generate_account.return_value = oauth_setup["account"]
mock_generate_account.return_value = (oauth_setup["account"], True)
mock_account_service.login.return_value = oauth_setup["token_pair"]
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
@ -179,7 +179,7 @@ class TestOAuthCallback:
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
mock_redirect.assert_called_once_with("http://localhost:3000")
mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=true")
@pytest.mark.parametrize(
("exception", "expected_error"),
@ -223,7 +223,7 @@ class TestOAuthCallback:
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED.value,
"http://localhost:3000",
"http://localhost:3000?oauth_new_user=false",
),
],
)
@ -260,7 +260,7 @@ class TestOAuthCallback:
account = MagicMock()
account.status = account_status
account.id = "123"
mock_generate_account.return_value = account
mock_generate_account.return_value = (account, False)
# Mock login for CLOSED status
mock_token_pair = MagicMock()
@ -296,7 +296,7 @@ class TestOAuthCallback:
mock_account = MagicMock()
mock_account.status = AccountStatus.PENDING
mock_generate_account.return_value = mock_account
mock_generate_account.return_value = (mock_account, False)
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
@ -360,7 +360,7 @@ class TestOAuthCallback:
closed_account.status = AccountStatus.CLOSED
closed_account.id = "123"
closed_account.name = "Closed Account"
mock_generate_account.return_value = closed_account
mock_generate_account.return_value = (closed_account, False)
# Mock successful login (current behavior)
mock_token_pair = MagicMock()
@ -374,7 +374,7 @@ class TestOAuthCallback:
resource.get("github")
# Verify current behavior: login succeeds (this is NOT ideal)
mock_redirect.assert_called_once_with("http://localhost:3000")
mock_redirect.assert_called_once_with("http://localhost:3000?oauth_new_user=false")
mock_account_service.login.assert_called_once()
# Document expected behavior in comments:
@ -458,8 +458,9 @@ class TestAccountGeneration:
with pytest.raises(AccountRegisterError):
_generate_account("github", user_info)
else:
result = _generate_account("github", user_info)
result, oauth_new_user = _generate_account("github", user_info)
assert result == mock_account
assert oauth_new_user == should_create
if should_create:
mock_register_service.register.assert_called_once_with(
@ -490,9 +491,10 @@ class TestAccountGeneration:
mock_tenant_service.create_tenant.return_value = mock_new_tenant
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
result = _generate_account("github", user_info)
result, oauth_new_user = _generate_account("github", user_info)
assert result == mock_account
assert oauth_new_user is False
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
mock_tenant_service.create_tenant_member.assert_called_once_with(
mock_new_tenant, mock_account, role="owner"

View File

@ -0,0 +1,213 @@
from core.rag.cleaner.clean_processor import CleanProcessor
class TestCleanProcessor:
"""Test cases for CleanProcessor.clean method."""
def test_clean_default_removal_of_invalid_symbols(self):
"""Test default cleaning removes invalid symbols."""
# Test <| replacement
assert CleanProcessor.clean("text<|with<|invalid", None) == "text<with<invalid"
# Test |> replacement
assert CleanProcessor.clean("text|>with|>invalid", None) == "text>with>invalid"
# Test removal of control characters
text_with_control = "normal\x00text\x1fwith\x07control\x7fchars"
expected = "normaltextwithcontrolchars"
assert CleanProcessor.clean(text_with_control, None) == expected
# Test U+FFFE removal
text_with_ufffe = "normal\ufffepadding"
expected = "normalpadding"
assert CleanProcessor.clean(text_with_ufffe, None) == expected
def test_clean_with_none_process_rule(self):
"""Test cleaning with None process_rule - only default cleaning applied."""
text = "Hello<|World\x00"
expected = "Hello<World"
assert CleanProcessor.clean(text, None) == expected
def test_clean_with_empty_process_rule(self):
"""Test cleaning with empty process_rule dict - only default cleaning applied."""
text = "Hello<|World\x00"
expected = "Hello<World"
assert CleanProcessor.clean(text, {}) == expected
def test_clean_with_empty_rules(self):
"""Test cleaning with empty rules - only default cleaning applied."""
text = "Hello<|World\x00"
expected = "Hello<World"
assert CleanProcessor.clean(text, {"rules": {}}) == expected
def test_clean_remove_extra_spaces_enabled(self):
"""Test remove_extra_spaces rule when enabled."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}]}}
# Test multiple newlines reduced to two
text = "Line1\n\n\n\n\nLine2"
expected = "Line1\n\nLine2"
assert CleanProcessor.clean(text, process_rule) == expected
# Test various whitespace characters reduced to single space
text = "word1\u2000\u2001\t\t \u3000word2"
expected = "word1 word2"
assert CleanProcessor.clean(text, process_rule) == expected
# Test combination of newlines and spaces
text = "Line1\n\n\n\n \t Line2"
expected = "Line1\n\n Line2"
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_remove_extra_spaces_disabled(self):
"""Test remove_extra_spaces rule when disabled."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": False}]}}
text = "Line1\n\n\n\n\nLine2 with spaces"
# Should only apply default cleaning (no invalid symbols here)
assert CleanProcessor.clean(text, process_rule) == text
def test_clean_remove_urls_emails_enabled(self):
"""Test remove_urls_emails rule when enabled."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
# Test email removal
text = "Contact us at test@example.com for more info"
expected = "Contact us at for more info"
assert CleanProcessor.clean(text, process_rule) == expected
# Test URL removal
text = "Visit https://example.com or http://test.org"
expected = "Visit or "
assert CleanProcessor.clean(text, process_rule) == expected
# Test both email and URL
text = "Email me@test.com and visit https://site.com"
expected = "Email and visit "
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_preserve_markdown_links_and_images(self):
"""Test that markdown links and images are preserved when removing URLs."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
# Test markdown link preservation
text = "Check [Google](https://google.com) for info"
expected = "Check [Google](https://google.com) for info"
assert CleanProcessor.clean(text, process_rule) == expected
# Test markdown image preservation
text = "Image: ![alt](https://example.com/image.png)"
expected = "Image: ![alt](https://example.com/image.png)"
assert CleanProcessor.clean(text, process_rule) == expected
# Test both link and image preservation
text = "[Link](https://link.com) and ![Image](https://image.com/img.jpg)"
expected = "[Link](https://link.com) and ![Image](https://image.com/img.jpg)"
assert CleanProcessor.clean(text, process_rule) == expected
# Test that non-markdown URLs are still removed
text = "Check [Link](https://keep.com) but remove https://remove.com"
expected = "Check [Link](https://keep.com) but remove "
assert CleanProcessor.clean(text, process_rule) == expected
# Test email removal alongside markdown preservation
text = "Email: test@test.com, link: [Click](https://site.com)"
expected = "Email: , link: [Click](https://site.com)"
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_remove_urls_emails_disabled(self):
"""Test remove_urls_emails rule when disabled."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": False}]}}
text = "Email test@example.com visit https://example.com"
# Should only apply default cleaning
assert CleanProcessor.clean(text, process_rule) == text
def test_clean_both_rules_enabled(self):
"""Test both pre-processing rules enabled together."""
process_rule = {
"rules": {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": True},
]
}
}
text = "Hello\n\n\n\n World test@example.com \n\n\nhttps://example.com"
expected = "Hello\n\n World \n\n"
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_with_markdown_link_and_extra_spaces(self):
"""Test markdown link preservation with extra spaces removal."""
process_rule = {
"rules": {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": True},
]
}
}
text = "[Link](https://example.com)\n\n\n\n Text https://remove.com"
expected = "[Link](https://example.com)\n\n Text "
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_unknown_rule_id_ignored(self):
"""Test that unknown rule IDs are silently ignored."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "unknown_rule", "enabled": True}]}}
text = "Hello<|World\x00"
expected = "Hello<World"
# Only default cleaning should be applied
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_empty_text(self):
"""Test cleaning empty text."""
assert CleanProcessor.clean("", None) == ""
assert CleanProcessor.clean("", {}) == ""
assert CleanProcessor.clean("", {"rules": {}}) == ""
def test_clean_text_with_only_invalid_symbols(self):
"""Test text containing only invalid symbols."""
text = "<|<|\x00\x01\x02\ufffe|>|>"
# <| becomes <, |> becomes >, control chars and U+FFFE are removed
assert CleanProcessor.clean(text, None) == "<<>>"
def test_clean_multiple_markdown_links_preserved(self):
"""Test multiple markdown links are all preserved."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
text = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
expected = "[One](https://one.com) [Two](http://two.org) [Three](https://three.net)"
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_markdown_link_text_as_url(self):
"""Test markdown link where the link text itself is a URL."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
# Link text that looks like URL should be preserved
text = "[https://text-url.com](https://actual-url.com)"
expected = "[https://text-url.com](https://actual-url.com)"
assert CleanProcessor.clean(text, process_rule) == expected
# Text URL without markdown should be removed
text = "https://text-url.com https://actual-url.com"
expected = " "
assert CleanProcessor.clean(text, process_rule) == expected
def test_clean_complex_markdown_link_content(self):
"""Test markdown links with complex content - known limitation with brackets in link text."""
process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}}
# Note: The regex pattern [^\]]* cannot handle ] within link text
# This is a known limitation - the pattern stops at the first ]
text = "[Text with [brackets] and (parens)](https://example.com)"
# Actual behavior: only matches up to first ], URL gets removed
expected = "[Text with [brackets] and (parens)]("
assert CleanProcessor.clean(text, process_rule) == expected
# Test that properly formatted markdown links work
text = "[Text with (parens) and symbols](https://example.com)"
expected = "[Text with (parens) and symbols](https://example.com)"
assert CleanProcessor.clean(text, process_rule) == expected

View File

@ -0,0 +1,186 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
import core.rag.extractor.pdf_extractor as pe
@pytest.fixture
def mock_dependencies(monkeypatch):
# Mock storage
saves = []
def save(key, data):
saves.append((key, data))
monkeypatch.setattr(pe, "storage", SimpleNamespace(save=save))
# Mock db
class DummySession:
def __init__(self):
self.added = []
self.committed = False
def add(self, obj):
self.added.append(obj)
def add_all(self, objs):
self.added.extend(objs)
def commit(self):
self.committed = True
db_stub = SimpleNamespace(session=DummySession())
monkeypatch.setattr(pe, "db", db_stub)
# Mock UploadFile
class FakeUploadFile:
DEFAULT_ID = "test_file_id"
def __init__(self, **kwargs):
# Assign id from DEFAULT_ID, allow override via kwargs if needed
self.id = self.DEFAULT_ID
for k, v in kwargs.items():
setattr(self, k, v)
monkeypatch.setattr(pe, "UploadFile", FakeUploadFile)
# Mock config
monkeypatch.setattr(pe.dify_config, "FILES_URL", "http://files.local")
monkeypatch.setattr(pe.dify_config, "INTERNAL_FILES_URL", None)
monkeypatch.setattr(pe.dify_config, "STORAGE_TYPE", "local")
return SimpleNamespace(saves=saves, db=db_stub, UploadFile=FakeUploadFile)
@pytest.mark.parametrize(
("image_bytes", "expected_mime", "expected_ext", "file_id"),
[
(b"\xff\xd8\xff some jpeg", "image/jpeg", "jpg", "test_file_id_jpeg"),
(b"\x89PNG\r\n\x1a\n some png", "image/png", "png", "test_file_id_png"),
],
)
def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, expected_mime, expected_ext, file_id):
saves = mock_dependencies.saves
db_stub = mock_dependencies.db
# Customize FakeUploadFile id for this test case.
# Using monkeypatch ensures the class attribute is reset between parameter sets.
monkeypatch.setattr(mock_dependencies.UploadFile, "DEFAULT_ID", file_id)
# Mock page and image objects
mock_page = MagicMock()
mock_image_obj = MagicMock()
def mock_extract(buf, fb_format=None):
buf.write(image_bytes)
mock_image_obj.extract.side_effect = mock_extract
mock_page.get_objects.return_value = [mock_image_obj]
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
# We need to handle the import inside _extract_images
with patch("pypdfium2.raw") as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
assert f"![image](http://files.local/files/{file_id}/file-preview)" in result
assert len(saves) == 1
assert saves[0][1] == image_bytes
assert len(db_stub.session.added) == 1
assert db_stub.session.added[0].tenant_id == "t1"
assert db_stub.session.added[0].size == len(image_bytes)
assert db_stub.session.added[0].mime_type == expected_mime
assert db_stub.session.added[0].extension == expected_ext
assert db_stub.session.committed is True
@pytest.mark.parametrize(
("get_objects_side_effect", "get_objects_return_value"),
[
(None, []), # Empty list
(None, None), # None returned
(Exception("Failed to get objects"), None), # Exception raised
],
)
def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_side_effect, get_objects_return_value):
mock_page = MagicMock()
if get_objects_side_effect:
mock_page.get_objects.side_effect = get_objects_side_effect
else:
mock_page.get_objects.return_value = get_objects_return_value
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
with patch("pypdfium2.raw") as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
assert result == ""
def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
# Mock pypdfium2
mock_pdf_doc = MagicMock()
mock_page = MagicMock()
mock_pdf_doc.__iter__.return_value = [mock_page]
# Mock text extraction
mock_text_page = MagicMock()
mock_text_page.get_text_range.return_value = "Page text content"
mock_page.get_textpage.return_value = mock_text_page
with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc):
# Mock Blob
mock_blob = MagicMock()
mock_blob.source = "test.pdf"
with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob):
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
# Mock _extract_images to return a known string
monkeypatch.setattr(extractor, "_extract_images", lambda p: "![image](img_url)")
documents = list(extractor.extract())
assert len(documents) == 1
assert "Page text content" in documents[0].page_content
assert "![image](img_url)" in documents[0].page_content
assert documents[0].metadata["page"] == 0
def test_extract_images_failures(mock_dependencies):
saves = mock_dependencies.saves
db_stub = mock_dependencies.db
# Mock page and image objects
mock_page = MagicMock()
mock_image_obj_fail = MagicMock()
mock_image_obj_ok = MagicMock()
# First image raises exception
mock_image_obj_fail.extract.side_effect = Exception("Extraction failure")
# Second image is OK (JPEG)
jpeg_bytes = b"\xff\xd8\xff some image data"
def mock_extract(buf, fb_format=None):
buf.write(jpeg_bytes)
mock_image_obj_ok.extract.side_effect = mock_extract
mock_page.get_objects.return_value = [mock_image_obj_fail, mock_image_obj_ok]
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
with patch("pypdfium2.raw") as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
# Should have one success
assert "![image](http://files.local/files/test_file_id/file-preview)" in result
assert len(saves) == 1
assert saves[0][1] == jpeg_bytes
assert db_stub.session.committed is True

View File

@ -0,0 +1 @@
"""Tests for graph traversal components."""

View File

@ -0,0 +1,307 @@
"""Unit tests for skip propagator."""
from unittest.mock import MagicMock, create_autospec
from core.workflow.graph import Edge, Graph
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator
class TestSkipPropagator:
"""Test suite for SkipPropagator."""
def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None:
"""When there are unknown incoming edges, propagation should stop."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create a mock edge
mock_edge = MagicMock(spec=Edge)
mock_edge.id = "edge_1"
mock_edge.head = "node_2"
# Setup graph edges dict
mock_graph.edges = {"edge_1": mock_edge}
# Setup incoming edges
incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)]
mock_graph.get_incoming_edges.return_value = incoming_edges
# Setup state manager to return has_unknown=True
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": True,
"has_taken": False,
"all_skipped": False,
}
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert
mock_graph.get_incoming_edges.assert_called_once_with("node_2")
mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges)
# Should not call any other state manager methods
mock_state_manager.enqueue_node.assert_not_called()
mock_state_manager.start_execution.assert_not_called()
mock_state_manager.mark_node_skipped.assert_not_called()
def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None:
"""When there is at least one taken edge, node should be enqueued."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create a mock edge
mock_edge = MagicMock(spec=Edge)
mock_edge.id = "edge_1"
mock_edge.head = "node_2"
mock_graph.edges = {"edge_1": mock_edge}
incoming_edges = [MagicMock(spec=Edge)]
mock_graph.get_incoming_edges.return_value = incoming_edges
# Setup state manager to return has_taken=True
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": False,
"has_taken": True,
"all_skipped": False,
}
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert
mock_state_manager.enqueue_node.assert_called_once_with("node_2")
mock_state_manager.start_execution.assert_called_once_with("node_2")
mock_state_manager.mark_node_skipped.assert_not_called()
def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None:
"""When all incoming edges are skipped, should propagate skip to node."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create a mock edge
mock_edge = MagicMock(spec=Edge)
mock_edge.id = "edge_1"
mock_edge.head = "node_2"
mock_graph.edges = {"edge_1": mock_edge}
incoming_edges = [MagicMock(spec=Edge)]
mock_graph.get_incoming_edges.return_value = incoming_edges
# Setup state manager to return all_skipped=True
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": False,
"has_taken": False,
"all_skipped": True,
}
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert
mock_state_manager.mark_node_skipped.assert_called_once_with("node_2")
mock_state_manager.enqueue_node.assert_not_called()
mock_state_manager.start_execution.assert_not_called()
def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None:
"""_propagate_skip_to_node should mark node and all outgoing edges as skipped."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create outgoing edges
edge1 = MagicMock(spec=Edge)
edge1.id = "edge_2"
edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge
edge2 = MagicMock(spec=Edge)
edge2.id = "edge_3"
edge2.head = "node_downstream_2"
# Setup graph edges dict for propagate_skip_from_edge
mock_graph.edges = {"edge_2": edge1, "edge_3": edge2}
mock_graph.get_outgoing_edges.return_value = [edge1, edge2]
# Setup get_incoming_edges to return empty list to stop recursion
mock_graph.get_incoming_edges.return_value = []
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Use mock to call private method
# Act
propagator._propagate_skip_to_node("node_1")
# Assert
mock_state_manager.mark_node_skipped.assert_called_once_with("node_1")
mock_state_manager.mark_edge_skipped.assert_any_call("edge_2")
mock_state_manager.mark_edge_skipped.assert_any_call("edge_3")
assert mock_state_manager.mark_edge_skipped.call_count == 2
# Should recursively propagate from each edge
# Since propagate_skip_from_edge is called, we need to verify it was called
# But we can't directly verify due to recursion. We'll trust the logic.
def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None:
"""skip_branch_paths should mark all unselected edges as skipped and propagate."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create unselected edges
edge1 = MagicMock(spec=Edge)
edge1.id = "edge_1"
edge1.head = "node_downstream_1"
edge2 = MagicMock(spec=Edge)
edge2.id = "edge_2"
edge2.head = "node_downstream_2"
unselected_edges = [edge1, edge2]
# Setup graph edges dict
mock_graph.edges = {"edge_1": edge1, "edge_2": edge2}
# Setup get_incoming_edges to return empty list to stop recursion
mock_graph.get_incoming_edges.return_value = []
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.skip_branch_paths(unselected_edges)
# Assert
mock_state_manager.mark_edge_skipped.assert_any_call("edge_1")
mock_state_manager.mark_edge_skipped.assert_any_call("edge_2")
assert mock_state_manager.mark_edge_skipped.call_count == 2
# propagate_skip_from_edge should be called for each edge
# We can't directly verify due to the mock, but the logic is covered
def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None:
"""Skip propagation should recursively propagate through the graph."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
# Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4
edge1 = MagicMock(spec=Edge)
edge1.id = "edge_1"
edge1.head = "node_2"
edge3 = MagicMock(spec=Edge)
edge3.id = "edge_3"
edge3.head = "node_4"
mock_graph.edges = {"edge_1": edge1, "edge_3": edge3}
# Setup get_incoming_edges to return different values based on node
def get_incoming_edges_side_effect(node_id):
if node_id == "node_2":
return [edge1]
elif node_id == "node_4":
return [edge3]
return []
mock_graph.get_incoming_edges.side_effect = get_incoming_edges_side_effect
# Setup get_outgoing_edges to return different values based on node
def get_outgoing_edges_side_effect(node_id):
if node_id == "node_2":
return [edge3]
elif node_id == "node_4":
return [] # No outgoing edges, stops recursion
return []
mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect
# Setup state manager to return all_skipped for both nodes
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": False,
"has_taken": False,
"all_skipped": True,
}
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert
# Should mark node_2 as skipped
mock_state_manager.mark_node_skipped.assert_any_call("node_2")
# Should mark edge_3 as skipped
mock_state_manager.mark_edge_skipped.assert_any_call("edge_3")
# Should propagate to node_4
mock_state_manager.mark_node_skipped.assert_any_call("node_4")
assert mock_state_manager.mark_node_skipped.call_count == 2
def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None:
"""Test with mixed edge states (some unknown, some taken, some skipped)."""
# Arrange
mock_graph = create_autospec(Graph)
mock_state_manager = create_autospec(GraphStateManager)
mock_edge = MagicMock(spec=Edge)
mock_edge.id = "edge_1"
mock_edge.head = "node_2"
mock_graph.edges = {"edge_1": mock_edge}
incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)]
mock_graph.get_incoming_edges.return_value = incoming_edges
# Test 1: has_unknown=True, has_taken=False, all_skipped=False
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": True,
"has_taken": False,
"all_skipped": False,
}
propagator = SkipPropagator(mock_graph, mock_state_manager)
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert - should stop processing
mock_state_manager.enqueue_node.assert_not_called()
mock_state_manager.mark_node_skipped.assert_not_called()
# Reset mocks for next test
mock_state_manager.reset_mock()
mock_graph.reset_mock()
# Test 2: has_unknown=False, has_taken=True, all_skipped=False
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": False,
"has_taken": True,
"all_skipped": False,
}
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert - should enqueue node
mock_state_manager.enqueue_node.assert_called_once_with("node_2")
mock_state_manager.start_execution.assert_called_once_with("node_2")
# Reset mocks for next test
mock_state_manager.reset_mock()
mock_graph.reset_mock()
# Test 3: has_unknown=False, has_taken=False, all_skipped=True
mock_state_manager.analyze_edge_states.return_value = {
"has_unknown": False,
"has_taken": False,
"all_skipped": True,
}
# Act
propagator.propagate_skip_from_edge("edge_1")
# Assert - should propagate skip
mock_state_manager.mark_node_skipped.assert_called_once_with("node_2")

View File

@ -8,11 +8,12 @@ class TestCelerySSLConfiguration:
"""Test suite for Celery SSL configuration."""
def test_get_celery_ssl_options_when_ssl_disabled(self):
"""Test SSL options when REDIS_USE_SSL is False."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = False
"""Test SSL options when BROKER_USE_SSL is False."""
from configs import DifyConfig
with patch("extensions.ext_celery.dify_config", mock_config):
dify_config = DifyConfig(CELERY_BROKER_URL="redis://localhost:6379/0")
with patch("extensions.ext_celery.dify_config", dify_config):
from extensions.ext_celery import _get_celery_ssl_options
result = _get_celery_ssl_options()
@ -21,7 +22,6 @@ class TestCelerySSLConfiguration:
def test_get_celery_ssl_options_when_broker_not_redis(self):
"""Test SSL options when broker is not Redis."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "amqp://localhost:5672"
with patch("extensions.ext_celery.dify_config", mock_config):
@ -33,7 +33,6 @@ class TestCelerySSLConfiguration:
def test_get_celery_ssl_options_with_cert_none(self):
"""Test SSL options with CERT_NONE requirement."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
mock_config.REDIS_SSL_CA_CERTS = None
@ -53,7 +52,6 @@ class TestCelerySSLConfiguration:
def test_get_celery_ssl_options_with_cert_required(self):
"""Test SSL options with CERT_REQUIRED and certificates."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0"
mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED"
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
@ -73,7 +71,6 @@ class TestCelerySSLConfiguration:
def test_get_celery_ssl_options_with_cert_optional(self):
"""Test SSL options with CERT_OPTIONAL requirement."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL"
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
@ -91,7 +88,6 @@ class TestCelerySSLConfiguration:
def test_get_celery_ssl_options_with_invalid_cert_reqs(self):
"""Test SSL options with invalid cert requirement defaults to CERT_NONE."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE"
mock_config.REDIS_SSL_CA_CERTS = None
@ -108,7 +104,6 @@ class TestCelerySSLConfiguration:
def test_celery_init_applies_ssl_to_broker_and_backend(self):
"""Test that SSL options are applied to both broker and backend when using Redis."""
mock_config = MagicMock()
mock_config.REDIS_USE_SSL = True
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.CELERY_BACKEND = "redis"
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"

View File

@ -0,0 +1,272 @@
import base64
import hashlib
from datetime import datetime
from unittest.mock import ANY, MagicMock
import pytest
from botocore.exceptions import ClientError
from libs import archive_storage as storage_module
from libs.archive_storage import (
ArchiveStorage,
ArchiveStorageError,
ArchiveStorageNotConfiguredError,
)
BUCKET_NAME = "archive-bucket"
def _configure_storage(monkeypatch, **overrides):
defaults = {
"ARCHIVE_STORAGE_ENABLED": True,
"ARCHIVE_STORAGE_ENDPOINT": "https://storage.example.com",
"ARCHIVE_STORAGE_ARCHIVE_BUCKET": BUCKET_NAME,
"ARCHIVE_STORAGE_ACCESS_KEY": "access",
"ARCHIVE_STORAGE_SECRET_KEY": "secret",
"ARCHIVE_STORAGE_REGION": "auto",
}
defaults.update(overrides)
for key, value in defaults.items():
monkeypatch.setattr(storage_module.dify_config, key, value, raising=False)
def _client_error(code: str) -> ClientError:
return ClientError({"Error": {"Code": code}}, "Operation")
def _mock_client(monkeypatch):
client = MagicMock()
client.head_bucket.return_value = None
boto_client = MagicMock(return_value=client)
monkeypatch.setattr(storage_module.boto3, "client", boto_client)
return client, boto_client
def test_init_disabled(monkeypatch):
_configure_storage(monkeypatch, ARCHIVE_STORAGE_ENABLED=False)
with pytest.raises(ArchiveStorageNotConfiguredError, match="not enabled"):
ArchiveStorage(bucket=BUCKET_NAME)
def test_init_missing_config(monkeypatch):
_configure_storage(monkeypatch, ARCHIVE_STORAGE_ENDPOINT=None)
with pytest.raises(ArchiveStorageNotConfiguredError, match="incomplete"):
ArchiveStorage(bucket=BUCKET_NAME)
def test_init_bucket_not_found(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.head_bucket.side_effect = _client_error("404")
with pytest.raises(ArchiveStorageNotConfiguredError, match="does not exist"):
ArchiveStorage(bucket=BUCKET_NAME)
def test_init_bucket_access_denied(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.head_bucket.side_effect = _client_error("403")
with pytest.raises(ArchiveStorageNotConfiguredError, match="Access denied"):
ArchiveStorage(bucket=BUCKET_NAME)
def test_init_bucket_other_error(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.head_bucket.side_effect = _client_error("500")
with pytest.raises(ArchiveStorageError, match="Failed to access archive bucket"):
ArchiveStorage(bucket=BUCKET_NAME)
def test_init_sets_client(monkeypatch):
_configure_storage(monkeypatch)
client, boto_client = _mock_client(monkeypatch)
storage = ArchiveStorage(bucket=BUCKET_NAME)
boto_client.assert_called_once_with(
"s3",
endpoint_url="https://storage.example.com",
aws_access_key_id="access",
aws_secret_access_key="secret",
region_name="auto",
config=ANY,
)
assert storage.client is client
assert storage.bucket == BUCKET_NAME
def test_put_object_returns_checksum(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
storage = ArchiveStorage(bucket=BUCKET_NAME)
data = b"hello"
checksum = storage.put_object("key", data)
expected_md5 = hashlib.md5(data).hexdigest()
expected_content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode()
client.put_object.assert_called_once_with(
Bucket="archive-bucket",
Key="key",
Body=data,
ContentMD5=expected_content_md5,
)
assert checksum == expected_md5
def test_put_object_raises_on_error(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
storage = ArchiveStorage(bucket=BUCKET_NAME)
client.put_object.side_effect = _client_error("500")
with pytest.raises(ArchiveStorageError, match="Failed to upload object"):
storage.put_object("key", b"data")
def test_get_object_returns_bytes(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
body = MagicMock()
body.read.return_value = b"payload"
client.get_object.return_value = {"Body": body}
storage = ArchiveStorage(bucket=BUCKET_NAME)
assert storage.get_object("key") == b"payload"
def test_get_object_missing(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.get_object.side_effect = _client_error("NoSuchKey")
storage = ArchiveStorage(bucket=BUCKET_NAME)
with pytest.raises(FileNotFoundError, match="Archive object not found"):
storage.get_object("missing")
def test_get_object_stream(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
body = MagicMock()
body.iter_chunks.return_value = [b"a", b"b"]
client.get_object.return_value = {"Body": body}
storage = ArchiveStorage(bucket=BUCKET_NAME)
assert list(storage.get_object_stream("key")) == [b"a", b"b"]
def test_get_object_stream_missing(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.get_object.side_effect = _client_error("NoSuchKey")
storage = ArchiveStorage(bucket=BUCKET_NAME)
with pytest.raises(FileNotFoundError, match="Archive object not found"):
list(storage.get_object_stream("missing"))
def test_object_exists(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
storage = ArchiveStorage(bucket=BUCKET_NAME)
assert storage.object_exists("key") is True
client.head_object.side_effect = _client_error("404")
assert storage.object_exists("missing") is False
def test_delete_object_error(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.delete_object.side_effect = _client_error("500")
storage = ArchiveStorage(bucket=BUCKET_NAME)
with pytest.raises(ArchiveStorageError, match="Failed to delete object"):
storage.delete_object("key")
def test_list_objects(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
paginator = MagicMock()
paginator.paginate.return_value = [
{"Contents": [{"Key": "a"}, {"Key": "b"}]},
{"Contents": [{"Key": "c"}]},
]
client.get_paginator.return_value = paginator
storage = ArchiveStorage(bucket=BUCKET_NAME)
assert storage.list_objects("prefix") == ["a", "b", "c"]
paginator.paginate.assert_called_once_with(Bucket="archive-bucket", Prefix="prefix")
def test_list_objects_error(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
paginator = MagicMock()
paginator.paginate.side_effect = _client_error("500")
client.get_paginator.return_value = paginator
storage = ArchiveStorage(bucket=BUCKET_NAME)
with pytest.raises(ArchiveStorageError, match="Failed to list objects"):
storage.list_objects("prefix")
def test_generate_presigned_url(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.generate_presigned_url.return_value = "http://signed-url"
storage = ArchiveStorage(bucket=BUCKET_NAME)
url = storage.generate_presigned_url("key", expires_in=123)
client.generate_presigned_url.assert_called_once_with(
ClientMethod="get_object",
Params={"Bucket": "archive-bucket", "Key": "key"},
ExpiresIn=123,
)
assert url == "http://signed-url"
def test_generate_presigned_url_error(monkeypatch):
_configure_storage(monkeypatch)
client, _ = _mock_client(monkeypatch)
client.generate_presigned_url.side_effect = _client_error("500")
storage = ArchiveStorage(bucket=BUCKET_NAME)
with pytest.raises(ArchiveStorageError, match="Failed to generate pre-signed URL"):
storage.generate_presigned_url("key")
def test_serialization_roundtrip():
records = [
{
"id": "1",
"created_at": datetime(2024, 1, 1, 12, 0, 0),
"payload": {"nested": "value"},
"items": [{"name": "a"}],
},
{"id": "2", "value": 123},
]
data = ArchiveStorage.serialize_to_jsonl_gz(records)
decoded = ArchiveStorage.deserialize_from_jsonl_gz(data)
assert decoded[0]["id"] == "1"
assert decoded[0]["payload"]["nested"] == "value"
assert decoded[0]["items"][0]["name"] == "a"
assert "2024-01-01T12:00:00" in decoded[0]["created_at"]
assert decoded[1]["value"] == 123
def test_content_md5_matches_checksum():
data = b"checksum"
expected = base64.b64encode(hashlib.md5(data).digest()).decode()
assert ArchiveStorage._content_md5(data) == expected
assert ArchiveStorage.compute_checksum(data) == hashlib.md5(data).hexdigest()

View File

@ -15,6 +15,11 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
("", ""),
(" ", " "),
("【测试】", "【测试】"),
# Markdown link preservation - should be preserved if text starts with a markdown link
("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
# Leading symbols before markdown link are removed, including the opening bracket [
("@[Test](https://example.com)", "Test](https://example.com)"),
],
)
def test_remove_leading_symbols(input_text, expected_output):

View File

@ -447,6 +447,15 @@ S3_SECRET_KEY=
# If set to false, the access key and secret key must be provided.
S3_USE_AWS_MANAGED_IAM=false
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto
# Azure Blob Configuration
#
AZURE_BLOB_ACCOUNT_NAME=difyai

View File

@ -122,6 +122,13 @@ x-shared-env: &shared-api-worker-env
S3_ACCESS_KEY: ${S3_ACCESS_KEY:-}
S3_SECRET_KEY: ${S3_SECRET_KEY:-}
S3_USE_AWS_MANAGED_IAM: ${S3_USE_AWS_MANAGED_IAM:-false}
ARCHIVE_STORAGE_ENABLED: ${ARCHIVE_STORAGE_ENABLED:-false}
ARCHIVE_STORAGE_ENDPOINT: ${ARCHIVE_STORAGE_ENDPOINT:-}
ARCHIVE_STORAGE_ARCHIVE_BUCKET: ${ARCHIVE_STORAGE_ARCHIVE_BUCKET:-}
ARCHIVE_STORAGE_EXPORT_BUCKET: ${ARCHIVE_STORAGE_EXPORT_BUCKET:-}
ARCHIVE_STORAGE_ACCESS_KEY: ${ARCHIVE_STORAGE_ACCESS_KEY:-}
ARCHIVE_STORAGE_SECRET_KEY: ${ARCHIVE_STORAGE_SECRET_KEY:-}
ARCHIVE_STORAGE_REGION: ${ARCHIVE_STORAGE_REGION:-auto}
AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai}
AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai}
AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container}

View File

@ -1,6 +1,7 @@
import type { Plan, UsagePlanInfo } from '@/app/components/billing/type'
import type { ProviderContextState } from '@/context/provider-context'
import { merge, noop } from 'es-toolkit/compat'
import { merge } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { defaultPlan } from '@/app/components/billing/config'
// Avoid being mocked in tests

View File

@ -4,7 +4,7 @@ import type { FC } from 'react'
import type { TriggerProps } from '@/app/components/base/date-and-time-picker/types'
import { RiCalendarLine } from '@remixicon/react'
import dayjs from 'dayjs'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useCallback } from 'react'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

View File

@ -1,6 +1,6 @@
'use client'
import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import Link from 'next/link'
import { useRouter, useSearchParams } from 'next/navigation'
import { useState } from 'react'

View File

@ -1,4 +1,4 @@
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useRouter, useSearchParams } from 'next/navigation'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -1,5 +1,5 @@
'use client'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import Link from 'next/link'
import { useRouter, useSearchParams } from 'next/navigation'
import { useCallback, useState } from 'react'

View File

@ -1,6 +1,6 @@
import type { ResponseError } from '@/service/fetch'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useState } from 'react'

View File

@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -4,7 +4,7 @@ import {
RiAddLine,
RiEditLine,
} from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'

View File

@ -4,7 +4,7 @@ import type { ExternalDataTool } from '@/models/common'
import type { PromptVariable } from '@/models/debug'
import type { GenRes } from '@/service/debug'
import { useBoolean } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import * as React from 'react'
import { useState } from 'react'

View File

@ -7,19 +7,24 @@ export const jsonObjectWrap = {
export const jsonConfigPlaceHolder = JSON.stringify(
{
foo: {
type: 'string',
},
bar: {
type: 'object',
properties: {
sub: {
type: 'number',
},
type: 'object',
properties: {
foo: {
type: 'string',
},
bar: {
type: 'object',
properties: {
sub: {
type: 'number',
},
},
required: [],
additionalProperties: true,
},
required: [],
additionalProperties: true,
},
required: [],
additionalProperties: true,
},
null,
2,

View File

@ -28,7 +28,7 @@ import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscoreInVarNameInpu
import ConfigSelect from '../config-select'
import ConfigString from '../config-string'
import ModalFoot from '../modal-foot'
import { jsonConfigPlaceHolder, jsonObjectWrap } from './config'
import { jsonConfigPlaceHolder } from './config'
import Field from './field'
import TypeSelector from './type-select'
@ -78,13 +78,12 @@ const ConfigModal: FC<IConfigModalProps> = ({
const modalRef = useRef<HTMLDivElement>(null)
const appDetail = useAppStore(state => state.appDetail)
const isBasicApp = appDetail?.mode !== AppModeEnum.ADVANCED_CHAT && appDetail?.mode !== AppModeEnum.WORKFLOW
const isSupportJSON = false
const jsonSchemaStr = useMemo(() => {
const isJsonObject = type === InputVarType.jsonObject
if (!isJsonObject || !tempPayload.json_schema)
return ''
try {
return JSON.stringify(JSON.parse(tempPayload.json_schema).properties, null, 2)
return JSON.stringify(JSON.parse(tempPayload.json_schema), null, 2)
}
catch {
return ''
@ -129,13 +128,14 @@ const ConfigModal: FC<IConfigModalProps> = ({
}, [])
const handleJSONSchemaChange = useCallback((value: string) => {
const isEmpty = value == null || value.trim() === ''
if (isEmpty) {
handlePayloadChange('json_schema')(undefined)
return null
}
try {
const v = JSON.parse(value)
const res = {
...jsonObjectWrap,
properties: v,
}
handlePayloadChange('json_schema')(JSON.stringify(res, null, 2))
handlePayloadChange('json_schema')(JSON.stringify(v, null, 2))
}
catch {
return null
@ -175,7 +175,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
},
]
: []),
...((!isBasicApp && isSupportJSON)
...((!isBasicApp)
? [{
name: t('variableConfig.json', { ns: 'appDebug' }),
value: InputVarType.jsonObject,
@ -233,7 +233,28 @@ const ConfigModal: FC<IConfigModalProps> = ({
const checkboxDefaultSelectValue = useMemo(() => getCheckboxDefaultSelectValue(tempPayload.default), [tempPayload.default])
const isJsonSchemaEmpty = (value: InputVar['json_schema']) => {
if (value === null || value === undefined) {
return true
}
if (typeof value !== 'string') {
return false
}
const trimmed = value.trim()
return trimmed === ''
}
const handleConfirm = () => {
const jsonSchemaValue = tempPayload.json_schema
const isSchemaEmpty = isJsonSchemaEmpty(jsonSchemaValue)
const normalizedJsonSchema = isSchemaEmpty ? undefined : jsonSchemaValue
// if the input type is jsonObject and the schema is empty as determined by `isJsonSchemaEmpty`,
// remove the `json_schema` field from the payload by setting its value to `undefined`.
const payloadToSave = tempPayload.type === InputVarType.jsonObject && isSchemaEmpty
? { ...tempPayload, json_schema: undefined }
: tempPayload
const moreInfo = tempPayload.variable === payload?.variable
? undefined
: {
@ -250,7 +271,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
return
}
if (isStringInput || type === InputVarType.number) {
onConfirm(tempPayload, moreInfo)
onConfirm(payloadToSave, moreInfo)
}
else if (type === InputVarType.select) {
if (options?.length === 0) {
@ -270,7 +291,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.optionRepeat', { ns: 'appDebug' }) })
return
}
onConfirm(tempPayload, moreInfo)
onConfirm(payloadToSave, moreInfo)
}
else if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) {
if (tempPayload.allowed_file_types?.length === 0) {
@ -283,10 +304,26 @@ const ConfigModal: FC<IConfigModalProps> = ({
Toast.notify({ type: 'error', message: errorMessages })
return
}
onConfirm(tempPayload, moreInfo)
onConfirm(payloadToSave, moreInfo)
}
else if (type === InputVarType.jsonObject) {
if (!isSchemaEmpty && typeof normalizedJsonSchema === 'string') {
try {
const schema = JSON.parse(normalizedJsonSchema)
if (schema?.type !== 'object') {
Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.jsonSchemaMustBeObject', { ns: 'appDebug' }) })
return
}
}
catch {
Toast.notify({ type: 'error', message: t('variableConfig.errorMsg.jsonSchemaInvalid', { ns: 'appDebug' }) })
return
}
}
onConfirm(payloadToSave, moreInfo)
}
else {
onConfirm(tempPayload, moreInfo)
onConfirm(payloadToSave, moreInfo)
}
}

View File

@ -5,15 +5,6 @@ import * as React from 'react'
import { AgentStrategy } from '@/types/app'
import AgentSettingButton from './agent-setting-button'
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string }) => {
const prefix = options?.ns ? `${options.ns}.` : ''
return `${prefix}${key}`
},
}),
}))
let latestAgentSettingProps: any
vi.mock('./agent/agent-setting', () => ({
default: (props: any) => {

View File

@ -2,7 +2,7 @@
import type { FC } from 'react'
import type { ExternalDataTool } from '@/models/common'
import copy from 'copy-to-clipboard'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'

View File

@ -15,15 +15,6 @@ vi.mock('use-context-selector', async (importOriginal) => {
}
})
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string, options?: { ns?: string }) => {
const prefix = options?.ns ? `${options.ns}.` : ''
return `${prefix}${key}`
},
}),
}))
const mockUseFeatures = vi.fn()
const mockUseFeaturesStore = vi.fn()
vi.mock('@/app/components/base/features/hooks', () => ({

View File

@ -1,4 +1,4 @@
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import Slider from '@/app/components/base/slider'

View File

@ -3,7 +3,7 @@ import type { Member } from '@/models/common'
import type { DataSet } from '@/models/datasets'
import type { RetrievalConfig } from '@/types/app'
import { RiCloseLine } from '@remixicon/react'
import { isEqual } from 'es-toolkit/compat'
import { isEqual } from 'es-toolkit/predicate'
import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'

View File

@ -1,7 +1,7 @@
'use client'
import type { ModelAndParameter } from '../types'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { createContext, useContext } from 'use-context-selector'
export type DebugWithMultipleModelContextType = {

View File

@ -4,7 +4,8 @@ import type {
OnSend,
TextGenerationConfig,
} from '@/app/components/base/text-generation/types'
import { cloneDeep, noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { cloneDeep } from 'es-toolkit/object'
import { memo } from 'react'
import TextGeneration from '@/app/components/app/text-generate/item'
import { TransferMethod } from '@/app/components/base/chat/types'

View File

@ -6,7 +6,7 @@ import type {
ChatConfig,
ChatItem,
} from '@/app/components/base/chat/types'
import { cloneDeep } from 'es-toolkit/compat'
import { cloneDeep } from 'es-toolkit/object'
import {
useCallback,
useRef,

View File

@ -11,7 +11,8 @@ import {
RiSparklingFill,
} from '@remixicon/react'
import { useBoolean } from 'ahooks'
import { cloneDeep, noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { cloneDeep } from 'es-toolkit/object'
import { produce, setAutoFreeze } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'

View File

@ -1,6 +1,6 @@
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ChatPromptConfig, CompletionPromptConfig, ConversationHistoriesRole, PromptItem } from '@/models/debug'
import { clone } from 'es-toolkit/compat'
import { clone } from 'es-toolkit/object'
import { produce } from 'immer'
import { useState } from 'react'
import { checkHasContextBlock, checkHasHistoryBlock, checkHasQueryBlock, PRE_PROMPT_PLACEHOLDER_TEXT } from '@/app/components/base/prompt-editor/constants'

View File

@ -20,7 +20,8 @@ import type {
import type { ModelConfig as BackendModelConfig, UserInputFormItem, VisionSettings } from '@/types/app'
import { CodeBracketIcon } from '@heroicons/react/20/solid'
import { useBoolean, useGetState } from 'ahooks'
import { clone, isEqual } from 'es-toolkit/compat'
import { clone } from 'es-toolkit/object'
import { isEqual } from 'es-toolkit/predicate'
import { produce } from 'immer'
import { usePathname } from 'next/navigation'
import * as React from 'react'

View File

@ -3,7 +3,7 @@ import type {
CodeBasedExtensionItem,
ExternalDataTool,
} from '@/models/common'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'

View File

@ -14,13 +14,6 @@ vi.mock('ahooks', () => ({
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({ isCurrentWorkspaceEditor: true }),
}))
vi.mock('use-context-selector', async () => {
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
return {
...actual,
useContext: () => ({ hasEditPermission: true }),
}
})
vi.mock('nuqs', () => ({
useQueryState: () => ['Recommended', vi.fn()],
}))
@ -119,6 +112,7 @@ describe('Apps', () => {
fireEvent.click(screen.getAllByTestId('app-card')[0])
expect(screen.getByTestId('create-from-template-modal')).toBeInTheDocument()
})
it('shows no template message when list is empty', () => {
mockUseExploreAppList.mockReturnValueOnce({
data: { allList: [], categories: [] },

View File

@ -8,7 +8,6 @@ import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppTypeSelector from '@/app/components/app/type-selector'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
@ -19,7 +18,6 @@ import CreateAppModal from '@/app/components/explore/create-app-modal'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import ExploreContext from '@/context/explore-context'
import { DSLImportMode } from '@/models/app'
import { importDSL } from '@/service/apps'
import { fetchAppDetail } from '@/service/explore'
@ -47,7 +45,6 @@ const Apps = ({
const { t } = useTranslation()
const { isCurrentWorkspaceEditor } = useAppContext()
const { push } = useRouter()
const { hasEditPermission } = useContext(ExploreContext)
const allCategoriesEn = AppCategories.RECOMMENDED
const [keywords, setKeywords] = useState('')
@ -214,7 +211,7 @@ const Apps = ({
<AppCard
key={app.app_id}
app={app}
canCreate={hasEditPermission}
canCreate={isCurrentWorkspaceEditor}
onCreate={() => {
setCurrApp(app)
setIsShowCreateModal(true)

View File

@ -3,7 +3,7 @@
import type { MouseEventHandler } from 'react'
import { RiCloseLine, RiCommandLine, RiCornerDownLeftLine } from '@remixicon/react'
import { useDebounceFn, useKeyPress } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useRouter } from 'next/navigation'
import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -1,7 +1,7 @@
'use client'
import type { AppIconType } from '@/types/app'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -3,7 +3,7 @@ import type { FC } from 'react'
import type { App } from '@/types/app'
import { useDebounce } from 'ahooks'
import dayjs from 'dayjs'
import { omit } from 'es-toolkit/compat'
import { omit } from 'es-toolkit/object'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'

View File

@ -12,7 +12,8 @@ import { RiCloseLine, RiEditFill } from '@remixicon/react'
import dayjs from 'dayjs'
import timezone from 'dayjs/plugin/timezone'
import utc from 'dayjs/plugin/utc'
import { get, noop } from 'es-toolkit/compat'
import { get } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'

View File

@ -2,7 +2,7 @@ import type { RenderOptions } from '@testing-library/react'
import type { Mock, MockedFunction } from 'vitest'
import type { ModalContextState } from '@/context/modal-context'
import { fireEvent, render } from '@testing-library/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { defaultPlan } from '@/app/components/billing/config'
import { useModalContext as actualUseModalContext } from '@/context/modal-context'

View File

@ -2,7 +2,7 @@
import type { App } from '@/types/app'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useRouter } from 'next/navigation'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -5,7 +5,7 @@ import { useDebounce } from 'ahooks'
import dayjs from 'dayjs'
import timezone from 'dayjs/plugin/timezone'
import utc from 'dayjs/plugin/utc'
import { omit } from 'es-toolkit/compat'
import { omit } from 'es-toolkit/object'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -2,7 +2,8 @@
import type { FC } from 'react'
import type { IChatItem } from '@/app/components/base/chat/chat/type'
import type { AgentIteration, AgentLogDetailResponse } from '@/models/log'
import { flatten, uniq } from 'es-toolkit/compat'
import { uniq } from 'es-toolkit/array'
import { flatten } from 'es-toolkit/compat'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -3,7 +3,7 @@ import type { Area } from 'react-easy-crop'
import type { OnImageInput } from './ImageInput'
import type { AppIconType, ImageFile } from '@/types/app'
import { RiImageCircleAiLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'

View File

@ -14,7 +14,7 @@ import type {
AppMeta,
ConversationItem,
} from '@/models/share'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { createContext, useContext } from 'use-context-selector'
export type ChatWithHistoryContextValue = {

View File

@ -10,7 +10,7 @@ import type {
ConversationItem,
} from '@/models/share'
import { useLocalStorageState } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import {
useCallback,

View File

@ -8,7 +8,8 @@ import type { InputForm } from './type'
import type AudioPlayer from '@/app/components/base/audio-btn/audio'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { Annotation } from '@/models/log'
import { noop, uniqBy } from 'es-toolkit/compat'
import { uniqBy } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce, setAutoFreeze } from 'immer'
import { useParams, usePathname } from 'next/navigation'
import {

View File

@ -13,7 +13,7 @@ import type {
AppMeta,
ConversationItem,
} from '@/models/share'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { createContext, useContext } from 'use-context-selector'
export type EmbeddedChatbotContextValue = {

View File

@ -9,7 +9,7 @@ import type {
ConversationItem,
} from '@/models/share'
import { useLocalStorageState } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import {
useCallback,

View File

@ -3,6 +3,7 @@ import type { Day } from '../types'
import dayjs from 'dayjs'
import timezone from 'dayjs/plugin/timezone'
import utc from 'dayjs/plugin/utc'
import { IS_PROD } from '@/config'
import tz from '@/utils/timezone.json'
dayjs.extend(utc)
@ -131,7 +132,7 @@ export type ToDayjsOptions = {
}
const warnParseFailure = (value: string) => {
if (process.env.NODE_ENV !== 'production')
if (!IS_PROD)
console.warn('[TimePicker] Failed to parse time value', value)
}

View File

@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -4,6 +4,7 @@ import { RiAlertLine, RiBugLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import Button from '@/app/components/base/button'
import { IS_DEV } from '@/config'
import { cn } from '@/utils/classnames'
type ErrorBoundaryState = {
@ -54,7 +55,7 @@ class ErrorBoundaryInner extends React.Component<
}
componentDidCatch(error: Error, errorInfo: ErrorInfo) {
if (process.env.NODE_ENV === 'development') {
if (IS_DEV) {
console.error('ErrorBoundary caught an error:', error)
console.error('Error Info:', errorInfo)
}
@ -262,13 +263,13 @@ export function withErrorBoundary<P extends object>(
// Simple error fallback component
export const ErrorFallback: React.FC<{
error: Error
resetErrorBoundary: () => void
}> = ({ error, resetErrorBoundary }) => {
resetErrorBoundaryAction: () => void
}> = ({ error, resetErrorBoundaryAction }) => {
return (
<div className="flex min-h-[200px] flex-col items-center justify-center rounded-lg border border-red-200 bg-red-50 p-8">
<h2 className="mb-2 text-lg font-semibold text-red-800">Oops! Something went wrong</h2>
<p className="mb-4 text-center text-red-600">{error.message}</p>
<Button onClick={resetErrorBoundary} size="small">
<Button onClick={resetErrorBoundaryAction} size="small">
Try again
</Button>
</div>

View File

@ -3,7 +3,7 @@ import type { InputVar } from '@/app/components/workflow/types'
import type { PromptVariable } from '@/models/debug'
import { RiAddLine, RiAsterisk, RiCloseLine, RiDeleteBinLine, RiDraggable } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'

View File

@ -2,7 +2,7 @@ import type { ChangeEvent, FC } from 'react'
import type { CodeBasedExtensionItem } from '@/models/common'
import type { ModerationConfig, ModerationContentConfig } from '@/models/debug'
import { RiCloseLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'

View File

@ -2,7 +2,7 @@ import type { ClipboardEvent } from 'react'
import type { FileEntity } from './types'
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileUploadConfigResponse } from '@/models/common'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import { useParams } from 'next/navigation'
import {

View File

@ -1,6 +1,6 @@
import type { FC } from 'react'
import { RiCloseLine, RiZoomInLine, RiZoomOutLine } from '@remixicon/react'
import { noop } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import { t } from 'i18next'
import * as React from 'react'
import { useState } from 'react'

Some files were not shown because too many files have changed in this diff Show More