Merge branch 'main' into feat/model-total-credits

Stephen Zhou 2026-01-04 10:59:46 +08:00
commit d9a0d6caa8
2903 changed files with 163811 additions and 145841 deletions

View File

@ -28,17 +28,14 @@ import userEvent from '@testing-library/user-event'
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
// No explicit mock needed for most tests
//
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// import { createReactI18nextMock } from '@/test/i18n-mock'
// vi.mock('react-i18next', () => createReactI18nextMock({
// 'my.custom.key': 'Custom Translation',
// 'button.save': 'Save',
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)

View File

@ -52,23 +52,29 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
The global mock provides:
- `useTranslation` - returns translation keys with namespace prefix
- `Trans` component - renders i18nKey and components
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
**Default behavior**: Most tests should use the global mock (no local override needed).
**For custom translations**: Use the helper function from `@/test/i18n-mock`:
```typescript
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
import { createReactI18nextMock } from '@/test/i18n-mock'
vi.mock('react-i18next', () => createReactI18nextMock({
'my.custom.key': 'Custom translation',
'button.save': 'Save',
}))
```
**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
### 2. Next.js Router
```typescript

View File

@ -110,6 +110,16 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run build
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -4,7 +4,7 @@ on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.ts'
- 'web/i18n/en-US/*.json'
permissions:
contents: write
@ -28,13 +28,13 @@ jobs:
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .ts)
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@ -65,7 +65,7 @@ jobs:
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'

View File

@ -101,6 +101,15 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
@ -128,6 +137,7 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name

View File

@ -1,4 +1,8 @@
exclude = ["migrations/*"]
exclude = [
"migrations/*",
".git",
".git/**",
]
line-length = 120
[format]

View File

@ -1,9 +1,11 @@
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
ArchiveStorageConfig,
NotionConfig,
SentryConfig,
):

View File

@ -0,0 +1,43 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class ArchiveStorageConfig(BaseSettings):
"""
Configuration settings for workflow run logs archiving storage.
"""
ARCHIVE_STORAGE_ENABLED: bool = Field(
description="Enable workflow run logs archiving to S3-compatible storage",
default=False,
)
ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
default=None,
)
ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
description="Name of the bucket to store archived workflow logs",
default=None,
)
ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
description="Name of the bucket to store exported workflow runs",
default=None,
)
ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
description="Access key ID for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
description="Secret access key for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_REGION: str = Field(
description="Region for storage (use 'auto' if the provider supports it)",
default="auto",
)

View File

@ -31,3 +31,8 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)
TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
description="Tencent Cloud COS custom domain setting",
default=None,
)

View File

@ -1,62 +1,59 @@
from flask_restx import Api, Namespace, fields
from __future__ import annotations
from libs.helper import AppIconUrlField
from typing import Any, TypeAlias
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
from pydantic import BaseModel, ConfigDict, computed_field
from core.file import helpers as file_helpers
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
def build_system_parameters_model(api_or_ns: Api | Namespace):
"""Build the system parameters model for the API or Namespace."""
return api_or_ns.model("SystemParameters", parameters__system_parameters)
class SystemParameters(BaseModel):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
class Parameters(BaseModel):
opening_statement: str | None = None
suggested_questions: list[str]
suggested_questions_after_answer: JSONObject
speech_to_text: JSONObject
text_to_speech: JSONObject
retriever_resource: JSONObject
annotation_reply: JSONObject
more_like_this: JSONObject
user_input_form: list[JSONObject]
sensitive_word_avoidance: JSONObject
file_upload: JSONObject
system_parameters: SystemParameters
def build_parameters_model(api_or_ns: Api | Namespace):
"""Build the parameters model for the API or Namespace."""
copied_fields = parameters_fields.copy()
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
return api_or_ns.model("Parameters", copied_fields)
class Site(BaseModel):
model_config = ConfigDict(from_attributes=True)
title: str
chat_color_theme: str | None = None
chat_color_theme_inverted: bool
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
default_language: str
show_workflow_steps: bool
use_icon_as_answer_icon: bool
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
def build_site_model(api_or_ns: Api | Namespace):
"""Build the site model for the API or Namespace."""
return api_or_ns.model("Site", site_fields)
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
if self.icon and self.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(self.icon)
return None
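
A quick, hedged sketch of how the new Pydantic response model can be serialized (the field values below are illustrative and not from the diff; running it requires the Dify app context for the `file_helpers`/`IconType` imports). The computed `icon_url` only resolves to a signed URL when the icon type is an image:

```python
site = Site.model_validate({
    "title": "My App",
    "chat_color_theme_inverted": False,
    "icon_type": "emoji",
    "icon": "🤖",
    "default_language": "en-US",
    "show_workflow_steps": True,
    "use_icon_as_answer_icon": False,
})
# icon_url is emitted by model_dump() as a computed field; None here because icon_type != "image"
print(site.model_dump(mode="json")["icon_url"])
```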

View File

@ -1,3 +1,4 @@
import re
import uuid
from typing import Literal
@ -73,6 +74,48 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
r"<script[^>]*>.*?</script>", # Script tags
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
r"javascript:", # JavaScript protocol
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
r"<embed[^>]*>", # Embed tags (self-closing)
r"<link[^>]*>", # Link tags with javascript
]
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
"""
Validate that a string value doesn't contain potential XSS payloads.
Args:
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The original value if safe
Raises:
ValueError: If the value contains XSS patterns
"""
if value is None:
return None
value_lower = value.lower()
for pattern in _XSS_PATTERNS:
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
raise ValueError(
f"{field_name} contains invalid characters or patterns. "
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
)
return value
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -81,6 +124,11 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@ -91,6 +139,11 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@ -99,6 +152,11 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
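
An illustration of the validator wired into the payload models above (the example strings are hypothetical):

```python
# Plain names pass through unchanged...
assert _validate_xss_safe("Customer FAQ bot", "name") == "Customer FAQ bot"

# ...while input matching any _XSS_PATTERNS entry raises ValueError.
try:
    _validate_xss_safe("<img src=x onerror=alert(1)>", "description")
except ValueError as exc:
    print(exc)  # "description contains invalid characters or patterns. ..."
```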

View File

@ -124,7 +124,7 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info)
account, oauth_new_user = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@ -159,7 +159,10 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@ -177,9 +180,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
return account
def _generate_account(provider: str, user_info: OAuthUserInfo):
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
@ -193,6 +197,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant)
if not account:
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
@ -220,4 +225,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
return account
return account, oauth_new_user
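
A standalone sketch of the redirect-URL construction added above (the helper name and example URLs are illustrative):

```python
def build_oauth_redirect(base_url: str, oauth_new_user: bool) -> str:
    # Append with '&' if the console URL already carries a query string, otherwise '?'
    query_char = "&" if "?" in base_url else "?"
    return f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"

assert build_oauth_redirect("https://cloud.example.com", True) == "https://cloud.example.com?oauth_new_user=true"
assert build_oauth_redirect("https://cloud.example.com/?utm=x", False).endswith("&oauth_new_user=false")
```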

View File

@ -3,10 +3,12 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@ -143,7 +145,29 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
.correlate(DocumentSegment)
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
keywords_condition,
)
)
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
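
A plain-Python illustration of what the PostgreSQL branch above ends up matching: the JSON `keywords` array is flattened into a comma-joined string and searched case-insensitively (the sample data is made up):

```python
import json

keywords_json = '["机器学习", "向量检索"]'  # stored JSON array of segment keywords
keyword = "检索"                             # user-supplied search term

# array_to_string(array(jsonb_array_elements_text(...)), ',') flattens the array like this:
flattened = ",".join(json.loads(keywords_json))
assert keyword.lower() in flattened.lower()  # roughly what the ILIKE '%keyword%' comparison checks
```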

View File

@ -1,5 +1,3 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@ -13,7 +11,6 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
@ -37,7 +34,8 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -20,7 +20,6 @@ from controllers.console.wraps import (
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@ -987,9 +986,6 @@ class ToolProviderMCPApi(Resource):
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put)
@ -1036,9 +1032,6 @@ class ToolProviderMCPApi(Resource):
validation_result=validation_result,
)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@ -1053,9 +1046,6 @@ class ToolProviderMCPApi(Resource):
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@ -1106,8 +1096,6 @@ class ToolMCPAuthApi(Resource):
credentials=provider_entity.credentials,
authed=True,
)
# Invalidate cache after updating credentials
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
except MCPAuthError as e:
try:
@ -1121,22 +1109,16 @@ class ToolMCPAuthApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
# Invalidate cache after auth actions may have updated provider state
ToolProviderListCache.invalidate_cache(tenant_id)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -1,7 +1,7 @@
from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
@ -92,7 +92,7 @@ annotation_list_fields = {
}
def build_annotation_list_model(api_or_ns: Api | Namespace):
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from controllers.common.fields import build_parameters_model
from controllers.common.fields import Parameters
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@ -23,7 +23,6 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app parameters.
@ -45,7 +44,8 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@service_api_ns.route("/meta")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from controllers.common.fields import build_site_model
from controllers.common.fields import Site as SiteResponse
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
@ -23,7 +23,6 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app site info.
@ -38,4 +37,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
return SiteResponse.model_validate(site).model_dump(mode="json")

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Api, Namespace, Resource, fields
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -78,7 +78,7 @@ workflow_run_fields = {
}
def build_workflow_run_model(api_or_ns: Api | Namespace):
def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)

View File

@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
@ -50,7 +50,6 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@ -69,7 +68,8 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
@web_ns.route("/meta")

View File

@ -22,6 +22,7 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -165,6 +166,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:

View File

@ -25,6 +25,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -222,6 +223,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + "\n"
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and tool_calls:
raise AgentMaxIterationError(app_config.agent.max_iteration)
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:

View File

@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@ -1,58 +0,0 @@
import json
import logging
from typing import Any, cast
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
keys = ["builtin", "model", "api", "workflow", "mcp"]
pipeline = redis_client.pipeline()
for key in keys:
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
pipeline.delete(cache_key)
pipeline.execute()

View File

@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)

View File

@ -285,7 +285,7 @@ class ModelProviderFactory:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_large)
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@ -309,13 +309,7 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

View File

@ -331,7 +331,6 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)

View File

@ -27,26 +27,44 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
# Remove URL but keep Markdown image URLs and link URLs
# Replace the ENTIRE markdown link/image with a single placeholder to protect
# the link text (which might also be a URL) from being removed
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
def replace_with_placeholder(match, placeholders=placeholders):
def replace_markdown_with_placeholder(match, placeholders=placeholders):
link_type = "link"
link_text = match.group(1)
url = match.group(2)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, link_text, url))
return placeholder
def replace_image_with_placeholder(match, placeholders=placeholders):
link_type = "image"
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, "image", url))
return placeholder
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Protect markdown links first
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
# Then protect markdown images
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
url_pattern = r"https?://\S+"
text = re.sub(url_pattern, "", text)
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
# Restore the Markdown links and images
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
if link_type == "link":
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
else: # image
text = text.replace(placeholder, f"![{text_or_alt}]({url})")
return text
def filter_string(self, text):
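
A brief illustration of the intended effect of the reworked placeholder logic (the sample text is made up; this is an expected input/output pair, not an executable test of `CleanProcessor`):

```python
raw = "See [docs](https://docs.example.com/guide) and ![logo](https://cdn.example.com/logo.png), or https://spam.example/x"
# Expected result: the trailing bare URL is stripped, while the markdown link and image are restored verbatim.
expected = "See [docs](https://docs.example.com/guide) and ![logo](https://cdn.example.com/logo.png), or "
```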

View File

@ -1,4 +1,5 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@ -36,6 +37,8 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@ -106,7 +109,12 @@ class RetrievalService:
)
)
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break
if exceptions:
raise ValueError(";\n".join(exceptions))
@ -210,6 +218,7 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -303,6 +312,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -351,6 +361,7 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@staticmethod
@ -663,7 +674,14 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
# Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break
if exceptions:
raise ValueError(";\n".join(exceptions))
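
The change above swaps a blocking `concurrent.futures.wait(...)` for `as_completed` so the first failure can cancel work that has not started yet. A minimal standalone sketch of that pattern (the function name is illustrative, not the actual `RetrievalService` code):

```python
import concurrent.futures


def run_fail_fast(tasks, timeout=300):
    """Run callables in a thread pool; stop waiting as soon as one of them raises."""
    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = [pool.submit(task) for task in tasks]
        for future in concurrent.futures.as_completed(futures, timeout=timeout):
            if future.exception():
                # Cancel whatever has not started yet; already-running futures finish on their own.
                for f in futures:
                    f.cancel()
                raise future.exception()
```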

View File

@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:

View File

@ -1,25 +1,57 @@
"""Abstract interface for document loader implementations."""
import contextlib
import io
import logging
import uuid
from collections.abc import Iterator
import pypdfium2
import pypdfium2.raw as pdfium_c
from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
"""Load pdf files.
"""
PdfExtractor is used to extract text and images from PDF files.
Args:
file_path: Path to the file to load.
file_path: Path to the PDF file.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
file_cache_key: Optional cache key for the extracted text.
"""
def __init__(self, file_path: str, file_cache_key: str | None = None):
"""Initialize with file path."""
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
IMAGE_FORMATS = [
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
(b"GIF8", "gif", "image/gif"),
(b"BM", "bmp", "image/bmp"),
(b"II*\x00", "tiff", "image/tiff"),
(b"MM\x00*", "tiff", "image/tiff"),
(b"II+\x00", "tiff", "image/tiff"),
(b"MM\x00+", "tiff", "image/tiff"),
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@ -50,7 +82,6 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@ -59,8 +90,87 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
image_content = self._extract_images(page)
if image_content:
content += "\n" + image_content
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
def _extract_images(self, page) -> str:
"""
Extract images from a PDF page, save them to storage and database,
and return markdown image links.
Args:
page: pypdfium2 page object.
Returns:
Markdown string containing links to the extracted images.
"""
image_content = []
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:
try:
# Extract image bytes
img_byte_arr = io.BytesIO()
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
# Fallback to png for other formats
obj.extract(img_byte_arr, fb_format="png")
img_bytes = img_byte_arr.getvalue()
if not img_bytes:
continue
header = img_bytes[: self.MAX_MAGIC_LEN]
image_ext = None
mime_type = None
for magic, ext, mime in self.IMAGE_FORMATS:
if header.startswith(magic):
image_ext = ext
mime_type = mime
break
if not image_ext or not mime_type:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
storage.save(file_key, img_bytes)
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=len(img_bytes),
extension=image_ext,
mime_type=mime_type,
created_by=self._user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self._user_id,
used_at=naive_utc_now(),
)
upload_files.append(upload_file)
image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)")
except Exception as e:
logger.warning("Failed to extract image from PDF: %s", e)
continue
except Exception as e:
logger.warning("Failed to get objects from PDF page: %s", e)
if upload_files:
db.session.add_all(upload_files)
db.session.commit()
return "\n".join(image_content)

View File

@ -516,6 +516,9 @@ class DatasetRetrieval:
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@ -534,6 +537,8 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@ -557,12 +562,25 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()
# Poll threads with short timeout to detect errors quickly (fail-fast)
while any(t.is_alive() for t in all_threads):
for thread in all_threads:
thread.join(timeout=0.1)
if thread_exceptions:
cancel_event.set()
break
if thread_exceptions:
break
if thread_exceptions:
raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
@ -1404,40 +1422,53 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
try:
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
# Check for cancellation signal
if cancel_event and cancel_event.is_set():
break
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
# Poll threads with short timeout to respond quickly to cancellation
while any(t.is_alive() for t in threads):
for thread in threads:
thread.join(timeout=0.1)
if cancel_event and cancel_event.is_set():
break
if cancel_event and cancel_event.is_set():
break
if reranking_enable:
# do rerank for searched documents
@ -1470,3 +1501,8 @@ class DatasetRetrieval:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
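
A compact sketch of the cooperative-cancellation contract introduced above: each worker thread watches a shared `threading.Event` and reports failures through a shared list so the joining caller can fail fast (the worker body and data are illustrative):

```python
import threading


def retrieve_worker(cancel_event: threading.Event, thread_exceptions: list[Exception], items: list[str]) -> None:
    """Illustrative worker: stops early when another thread has failed, reports its own failures."""
    try:
        for item in items:
            if cancel_event.is_set():  # a sibling thread already failed; bail out
                return
            if item == "bad":          # stand-in for a real retrieval error
                raise ValueError(f"cannot process {item}")
    except Exception as e:
        cancel_event.set()             # signal the other threads to stop
        thread_exceptions.append(e)    # let the joining caller re-raise


cancel_event = threading.Event()
errors: list[Exception] = []
t = threading.Thread(target=retrieve_worker, args=(cancel_event, errors, ["ok", "bad"]))
t.start()
t.join()
assert errors and cancel_event.is_set()
```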

View File

@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -4,6 +4,7 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@ -11,6 +12,11 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Check if text starts with a markdown link - preserve it
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
if re.match(markdown_link_pattern, text):
return text
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
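
A couple of quick checks against the new guard in `remove_leading_symbols` (the inputs are made up):

```python
import re

markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"

# A leading markdown link is preserved (the function returns the text unchanged)...
assert re.match(markdown_link_pattern, "[Release notes](https://example.com/notes) for v1.2") is not None
# ...while ordinary leading punctuation still falls through to the symbol-stripping branch.
assert re.match(markdown_link_pattern, "-- preamble text") is None
```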

View File

@ -54,7 +54,6 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
@ -67,7 +66,7 @@ class WorkflowToolProviderController(ToolProviderController):
credentials_schema=[],
plugin_id=None,
),
provider_id="",
provider_id=db_provider.id,
)
controller.tools = [

View File

@ -60,6 +60,7 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -12,9 +12,8 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.REDIS_USE_SSL:
if not dify_config.BROKER_USE_SSL:
return None
# Check if Celery is actually using Redis

View File

@ -13,12 +13,20 @@ class TencentCosStorage(BaseStorage):
super().__init__()
self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME
config = CosConfig(
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
if dify_config.TENCENT_COS_CUSTOM_DOMAIN:
config = CosConfig(
Domain=dify_config.TENCENT_COS_CUSTOM_DOMAIN,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
else:
config = CosConfig(
Region=dify_config.TENCENT_COS_REGION,
SecretId=dify_config.TENCENT_COS_SECRET_ID,
SecretKey=dify_config.TENCENT_COS_SECRET_KEY,
Scheme=dify_config.TENCENT_COS_SCHEME,
)
self.client = CosS3Client(config)
def save(self, filename, data):

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -12,7 +12,7 @@ annotation_fields = {
}
def build_annotation_model(api_or_ns: Api | Namespace):
def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@ -46,7 +46,7 @@ message_file_fields = {
}
def build_message_file_model(api_or_ns: Api | Namespace):
def build_message_file_model(api_or_ns: Namespace):
"""Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields)
@ -217,7 +217,7 @@ conversation_infinite_scroll_pagination_fields = {
}
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns)
@ -226,11 +226,11 @@ def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespa
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
def build_conversation_delete_model(api_or_ns: Api | Namespace):
def build_conversation_delete_model(api_or_ns: Namespace):
"""Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
def build_simple_conversation_model(api_or_ns: Api | Namespace):
def build_simple_conversation_model(api_or_ns: Namespace):
"""Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}
def build_conversation_variable_model(api_or_ns: Api | Namespace):
def build_conversation_variable_model(api_or_ns: Namespace):
"""Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ simple_end_user_fields = {
}
def build_simple_end_user_model(api_or_ns: Api | Namespace):
def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import TimestampField
@ -14,7 +14,7 @@ upload_config_fields = {
}
def build_upload_config_model(api_or_ns: Api | Namespace):
def build_upload_config_model(api_or_ns: Namespace):
"""Build the upload config model for the API or Namespace.
Args:
@ -39,7 +39,7 @@ file_fields = {
}
def build_file_model(api_or_ns: Api | Namespace):
def build_file_model(api_or_ns: Namespace):
"""Build the file model for the API or Namespace.
Args:
@ -57,7 +57,7 @@ remote_file_info_fields = {
}
def build_remote_file_info_model(api_or_ns: Api | Namespace):
def build_remote_file_info_model(api_or_ns: Namespace):
"""Build the remote file info model for the API or Namespace.
Args:
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
}
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
def build_file_with_signed_url_model(api_or_ns: Namespace):
"""Build the file with signed URL model for the API or Namespace.
Args:

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
@ -9,7 +9,7 @@ simple_account_fields = {
}
def build_simple_account_model(api_or_ns: Api | Namespace):
def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
@ -10,7 +10,7 @@ feedback_fields = {
}
def build_feedback_model(api_or_ns: Api | Namespace):
def build_feedback_model(api_or_ns: Namespace):
"""Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields)
@ -30,7 +30,7 @@ agent_thought_fields = {
}
def build_agent_thought_model(api_or_ns: Api | Namespace):
def build_agent_thought_model(api_or_ns: Namespace):
"""Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
dataset_tag_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ dataset_tag_fields = {
}
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
}
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
}
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
"""Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Api, Namespace, fields
from flask_restx import Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
}
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
def build_workflow_run_for_log_model(api_or_ns: Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)

api/libs/archive_storage.py (new file, 347 lines)
View File

@ -0,0 +1,347 @@
"""
Archive Storage Client for S3-compatible storage.
This module provides a dedicated storage client for archiving or exporting logs
to S3-compatible object storage.
"""
import base64
import datetime
import gzip
import hashlib
import logging
from collections.abc import Generator
from typing import Any, cast
import boto3
import orjson
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
logger = logging.getLogger(__name__)
class ArchiveStorageError(Exception):
"""Base exception for archive storage operations."""
pass
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
"""Raised when archive storage is not properly configured."""
pass
class ArchiveStorage:
"""
S3-compatible storage client for archiving or exporting.
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
"""
def __init__(self, bucket: str):
if not dify_config.ARCHIVE_STORAGE_ENABLED:
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
if not bucket:
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
if not all(
[
dify_config.ARCHIVE_STORAGE_ENDPOINT,
bucket,
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
]
):
raise ArchiveStorageNotConfiguredError(
"Archive storage configuration is incomplete. "
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
)
self.bucket = bucket
self.client = boto3.client(
"s3",
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
region_name=dify_config.ARCHIVE_STORAGE_REGION,
config=Config(s3={"addressing_style": "path"}),
)
# Verify bucket accessibility
try:
self.client.head_bucket(Bucket=self.bucket)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "404":
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
elif error_code == "403":
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
else:
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
def put_object(self, key: str, data: bytes) -> str:
"""
Upload an object to the archive storage.
Args:
key: Object key (path) within the bucket
data: Binary data to upload
Returns:
MD5 checksum of the uploaded data
Raises:
ArchiveStorageError: If upload fails
"""
checksum = hashlib.md5(data).hexdigest()
try:
self.client.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentMD5=self._content_md5(data),
)
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
return checksum
except ClientError as e:
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
def get_object(self, key: str) -> bytes:
"""
Download an object from the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
Binary data of the object
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
"""
Stream an object from the archive storage.
Args:
key: Object key (path) within the bucket
Yields:
Chunks of binary data
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
yield from response["Body"].iter_chunks()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
def object_exists(self, key: str) -> bool:
"""
Check if an object exists in the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
True if object exists, False otherwise
"""
try:
self.client.head_object(Bucket=self.bucket, Key=key)
return True
except ClientError:
return False
def delete_object(self, key: str) -> None:
"""
Delete an object from the archive storage.
Args:
key: Object key (path) within the bucket
Raises:
ArchiveStorageError: If deletion fails
"""
try:
self.client.delete_object(Bucket=self.bucket, Key=key)
logger.debug("Deleted object: %s", key)
except ClientError as e:
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
"""
Generate a pre-signed URL for downloading an object.
Args:
key: Object key (path) within the bucket
expires_in: URL validity duration in seconds (default: 1 hour)
Returns:
Pre-signed URL string.
Raises:
ArchiveStorageError: If generation fails
"""
try:
return self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket, "Key": key},
ExpiresIn=expires_in,
)
except ClientError as e:
raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
def list_objects(self, prefix: str) -> list[str]:
"""
List objects under a given prefix.
Args:
prefix: Object key prefix to filter by
Returns:
List of object keys matching the prefix
"""
keys = []
paginator = self.client.get_paginator("list_objects_v2")
try:
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
for obj in page.get("Contents", []):
keys.append(obj["Key"])
except ClientError as e:
raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
return keys
@staticmethod
def _content_md5(data: bytes) -> str:
"""Calculate base64-encoded MD5 for Content-MD5 header."""
return base64.b64encode(hashlib.md5(data).digest()).decode()
@staticmethod
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
"""
Serialize records to gzipped JSONL format.
Args:
records: List of dictionaries to serialize
Returns:
Gzipped JSONL bytes
"""
lines = []
for record in records:
# Convert datetime objects to ISO format strings
serialized = ArchiveStorage._serialize_record(record)
lines.append(orjson.dumps(serialized))
jsonl_content = b"\n".join(lines)
if jsonl_content:
jsonl_content += b"\n"
return gzip.compress(jsonl_content)
@staticmethod
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
"""
Deserialize gzipped JSONL data to records.
Args:
data: Gzipped JSONL bytes
Returns:
List of dictionaries
"""
jsonl_content = gzip.decompress(data)
records = []
for line in jsonl_content.splitlines():
if line:
records.append(orjson.loads(line))
return records
@staticmethod
def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
"""Serialize a single record, converting special types."""
def _serialize(item: Any) -> Any:
if isinstance(item, datetime.datetime):
return item.isoformat()
if isinstance(item, dict):
return {key: _serialize(value) for key, value in item.items()}
if isinstance(item, list):
return [_serialize(value) for value in item]
return item
return cast(dict[str, Any], _serialize(record))
@staticmethod
def compute_checksum(data: bytes) -> str:
"""Compute MD5 checksum of data."""
return hashlib.md5(data).hexdigest()
# Singleton instance (lazy initialization)
_archive_storage: ArchiveStorage | None = None
_export_storage: ArchiveStorage | None = None
def get_archive_storage() -> ArchiveStorage:
"""
Get the archive storage singleton instance.
Returns:
ArchiveStorage instance
Raises:
ArchiveStorageNotConfiguredError: If archive storage is not configured
"""
global _archive_storage
if _archive_storage is None:
archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
if not archive_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
)
_archive_storage = ArchiveStorage(bucket=archive_bucket)
return _archive_storage
def get_export_storage() -> ArchiveStorage:
"""
Get the export storage singleton instance.
Returns:
ArchiveStorage instance
"""
global _export_storage
if _export_storage is None:
export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
if not export_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
)
_export_storage = ArchiveStorage(bucket=export_bucket)
return _export_storage
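For reference, a minimal usage sketch of the client above — the record shape, object key layout, and expiry value are illustrative assumptions; only the `ArchiveStorage` methods themselves come from this file:

```python
# Minimal sketch: archive a batch of records as gzipped JSONL, then read them back.
import datetime

from libs.archive_storage import ArchiveStorage, get_archive_storage

storage: ArchiveStorage = get_archive_storage()  # raises if ARCHIVE_STORAGE_* is not configured

records = [
    {"id": "run-1", "status": "succeeded", "finished_at": datetime.datetime.now(datetime.timezone.utc)},
    {"id": "run-2", "status": "failed", "finished_at": datetime.datetime.now(datetime.timezone.utc)},
]

# Datetimes are converted to ISO strings before orjson serialization.
payload = ArchiveStorage.serialize_to_jsonl_gz(records)
key = "workflow_runs/2026/01/04/batch-0001.jsonl.gz"  # illustrative key layout

checksum = storage.put_object(key, payload)

# Later: fetch the object and restore the records.
restored = ArchiveStorage.deserialize_from_jsonl_gz(storage.get_object(key))
assert restored[0]["id"] == "run-1"

# Share a time-limited download link without exposing credentials.
url = storage.generate_presigned_url(key, expires_in=900)
```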

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@ -23,31 +20,17 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
batch_op.drop_column('description_str')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
# ### end Alembic commands ###
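The migrations in this commit drop their per-dialect `_is_pg` branches in favor of the portable column types in `models.types`, so a single code path serves both PostgreSQL and MySQL. A minimal sketch of the resulting pattern — the table and column names are illustrative, and the dialect-specific rendering of `LongText`/`StringUUID` is assumed to live inside `models.types`:

```python
# Minimal sketch of the dialect-agnostic migration style used throughout this commit.
import sqlalchemy as sa
from alembic import op

import models.types


def upgrade():
    with op.batch_alter_table("example_table", schema=None) as batch_op:
        batch_op.add_column(sa.Column("payload", models.types.LongText(), nullable=True))
        batch_op.add_column(sa.Column("owner_id", models.types.StringUUID(), nullable=True))


def downgrade():
    with op.batch_alter_table("example_table", schema=None) as batch_op:
        batch_op.drop_column("owner_id")
        batch_op.drop_column("payload")
```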

View File

@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@ -32,13 +28,7 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
else:
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
else:
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
# ### end Alembic commands ###

View File

@ -9,11 +9,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@ -23,58 +18,30 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=True)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=False)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'

View File

@ -28,85 +28,45 @@ def upgrade():
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
if _is_pg(conn):
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
else:
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###

View File

@ -49,57 +49,33 @@ def upgrade():
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=False)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=True)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
if _is_pg(conn):
with op.batch_alter_table('messages', schema=None) as batch_op:

View File

@ -86,57 +86,30 @@ def upgrade():
def migrate_existing_provider_models_data():
"""migrate provider_models table data to provider_model_credentials"""
conn = op.get_bind()
# Define table structure for data manipulation
if _is_pg(conn):
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', sa.Text()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
else:
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
# Define table structure for data manipulation
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
if _is_pg(conn):
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', sa.Text()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
else:
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
# Get database connection
@ -183,14 +156,8 @@ def migrate_existing_provider_models_data():
def downgrade():
# Re-add encrypted_config column to provider_models table
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
else:
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
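The data-migration helper above builds lightweight `table()`/`column()` constructs instead of importing ORM models. For reference, a minimal sketch of how rows are typically copied with such constructs — the table names, columns, and query are illustrative, not the ones used in this migration:

```python
# Minimal sketch: copying rows between tables inside an Alembic migration using
# sqlalchemy.sql table()/column() constructs and the migration connection.
import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table

source = table(
    "source_table",
    column("id", sa.String()),
    column("encrypted_config", sa.Text()),
)
target = table(
    "target_table",
    column("id", sa.String()),
    column("encrypted_config", sa.Text()),
)


def copy_rows():
    conn = op.get_bind()
    for row in conn.execute(sa.select(source.c.id, source.c.encrypted_config)):
        conn.execute(target.insert().values(id=row.id, encrypted_config=row.encrypted_config))
```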

View File

@ -8,7 +8,6 @@ Create Date: 2025-08-20 17:47:17.015695
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7
def _is_pg(conn):

View File

@ -9,8 +9,6 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@ -23,12 +21,7 @@ depends_on = None
def upgrade():
# Add encrypted_headers column to tool_mcp_providers table
conn = op.get_bind()
if _is_pg(conn):
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
else:
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
def downgrade():

View File

@ -44,6 +44,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
)
if _is_pg(conn):
op.create_table('datasource_oauth_tenant_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -70,6 +71,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
)
if _is_pg(conn):
op.create_table('datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -104,6 +106,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)
@ -133,6 +136,7 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
)
with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
@ -174,6 +178,7 @@ def upgrade():
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
)
if _is_pg(conn):
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -193,7 +198,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
else:
# MySQL: Use compatible syntax
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@ -211,6 +215,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
@ -236,6 +241,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
)
if _is_pg(conn):
op.create_table('pipelines',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -266,6 +272,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
)
if _is_pg(conn):
op.create_table('workflow_draft_variable_files',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -292,6 +299,7 @@ def upgrade():
sa.Column('value_type', sa.String(20), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
)
if _is_pg(conn):
op.create_table('workflow_node_execution_offload',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -316,6 +324,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
)
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
@ -342,6 +351,7 @@ def upgrade():
comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
)
batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))

View File

@ -9,8 +9,6 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@ -33,15 +31,9 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
else:
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###

View File

@ -9,7 +9,6 @@ Create Date: 2025-10-22 16:11:31.805407
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7
def _is_pg(conn):
return conn.dialect.name == "postgresql"

View File

@ -105,6 +105,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
)
if _is_pg(conn):
op.create_table('trigger_subscriptions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
@ -143,6 +144,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
)
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
@ -176,6 +178,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
)
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)
@ -207,6 +210,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
)
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
@ -264,6 +268,7 @@ def upgrade():
sa.Column('finished_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
)
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
@ -299,6 +304,7 @@ def upgrade():
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
)
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '23db93619b9d'
down_revision = '8ae9bc661daa'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
else:
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -62,14 +62,8 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')

View File

@ -11,9 +11,6 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '2a3aebbbf4bb'
down_revision = 'c031d46af369'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
else:
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -7,14 +7,10 @@ Create Date: 2023-09-22 15:41:01.243183
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
@ -24,35 +20,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
else:
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
else:
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
# ### end Alembic commands ###

View File

@ -7,14 +7,10 @@ Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
@ -24,59 +20,31 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
else:
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=False)
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -6,14 +6,10 @@ Create Date: 2024-01-12 03:42:27.362415
"""
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '4829e54d2fee'
down_revision = '114eed84c228'
@ -23,39 +19,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=postgresql.UUID(),
nullable=True)
else:
# MySQL: Use compatible syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=True)
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
# MySQL: Use compatible syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=False)
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -6,14 +6,10 @@ Create Date: 2024-03-14 04:54:56.679506
"""
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '563cf8bf777b'
down_revision = 'b5429b71023c'
@ -23,35 +19,19 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
else:
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -48,12 +48,9 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
else:
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '714aafe25d39'
down_revision = 'f2a6fc85e260'
@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False))
else:
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_question', models.types.LongText(), nullable=False))
batch_op.add_column(sa.Column('annotation_content', models.types.LongText(), nullable=False))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '77e83833755c'
down_revision = '6dcb43972bdc'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('retriever_resource', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -27,7 +27,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
op.create_table('tool_providers',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=False),
@ -40,7 +39,6 @@ def upgrade():
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
else:
# MySQL: Use compatible syntax
op.create_table('tool_providers',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@ -52,12 +50,9 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('sensitive_word_avoidance', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '88072f0caa04'
down_revision = '246ba09cbbdb'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tenants', schema=None) as batch_op:
batch_op.add_column(sa.Column('custom_config', sa.Text(), nullable=True))
else:
with op.batch_alter_table('tenants', schema=None) as batch_op:
batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
with op.batch_alter_table('tenants', schema=None) as batch_op:
batch_op.add_column(sa.Column('custom_config', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '89c7899ca936'
down_revision = '187385f442fc'
@ -23,39 +20,21 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=sa.VARCHAR(length=255),
type_=sa.Text(),
existing_nullable=True)
else:
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
existing_nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
existing_nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=sa.Text(),
type_=sa.VARCHAR(length=255),
existing_nullable=True)
else:
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
existing_nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('description',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
existing_nullable=True)
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '8ec536f3c800'
down_revision = 'ad472b61a054'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credentials_str', sa.Text(), nullable=False))
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credentials_str', models.types.LongText(), nullable=False))
# ### end Alembic commands ###

View File

@ -57,12 +57,9 @@ def upgrade():
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('file_upload', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('upload_files', schema=None) as batch_op:

View File

@ -24,7 +24,6 @@ def upgrade():
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'::character varying"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')
@ -35,7 +34,6 @@ def upgrade():
batch_op.drop_index('saved_message_message_idx')
batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by_role', 'created_by'], unique=False)
else:
# MySQL: Use compatible syntax
with op.batch_alter_table('pinned_conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'end_user'"), nullable=False))
batch_op.drop_index('pinned_conversation_conversation_idx')

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = 'a5b56fb053ef'
down_revision = 'd3d503a3471c'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('speech_to_text', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = 'a9836e3baeee'
down_revision = '968fff4c0ab9'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('external_data_tools', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = 'b24be59fbb04'
down_revision = 'de95f5c77138'
@ -23,14 +20,8 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('text_to_speech', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('text_to_speech', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = 'b3a09c049e8e'
down_revision = '2e9819ca5b28'
@ -23,20 +20,11 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple'))
batch_op.add_column(sa.Column('chat_prompt_config', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('completion_prompt_config', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('dataset_configs', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -7,7 +7,6 @@ Create Date: 2024-06-17 10:01:00.255189
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types

View File

@ -54,12 +54,9 @@ def upgrade():
batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False)
batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False)
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
@ -68,54 +65,31 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'"), nullable=False))
if _is_pg(conn):
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', sa.Text(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=True)
else:
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=models.types.StringUUID(),
nullable=True)
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.add_column(sa.Column('question', models.types.LongText(), nullable=True))
batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False))
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('message_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
else:
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('message_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.drop_column('hit_count')
batch_op.drop_column('question')
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.drop_column('type')

View File

@ -12,9 +12,6 @@ from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = 'f2a6fc85e260'
down_revision = '46976cc39132'
@ -24,16 +21,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
else:
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_id', models.types.StringUUID(), nullable=False))
batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False)
# ### end Alembic commands ###

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column
from sqlalchemy.orm import Mapped, Session, mapped_column, validates
from typing_extensions import deprecated
from .base import TypeBase
@ -116,6 +116,12 @@ class Account(UserMixin, TypeBase):
role: TenantAccountRole | None = field(default=None, init=False)
_current_tenant: "Tenant | None" = field(default=None, init=False)
@validates("status")
def _normalize_status(self, _key: str, value: str | AccountStatus) -> str:
if isinstance(value, AccountStatus):
return value.value
return value
@property
def is_password_set(self):
return self.password is not None
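
Editor's note: the `@validates("status")` hook added above normalizes `AccountStatus` enum values to their plain string form at assignment time. A standalone sketch of the same SQLAlchemy mechanism, with a simplified base class and placeholder enum members (not the actual `Account` model):

```python
# Standalone sketch of the @validates normalization pattern (simplified types).
import enum

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, validates


class AccountStatus(str, enum.Enum):  # placeholder members, not the real enum
    PENDING = "pending"
    ACTIVE = "active"


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "accounts"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(16), default=AccountStatus.ACTIVE.value)

    @validates("status")
    def _normalize_status(self, _key: str, value: "str | AccountStatus") -> str:
        # Accept either the enum or its string value; always store the plain string.
        if isinstance(value, AccountStatus):
            return value.value
        return value


account = Account(status=AccountStatus.PENDING)
assert account.status == "pending"
assert type(account.status) is str  # enum was normalized on assignment
```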

View File

@ -16,6 +16,11 @@ celery_redis = Redis(
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
ssl=bool(dify_config.BROKER_USE_SSL),
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
)
logger = logging.getLogger(__name__)
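
Editor's note: the Celery health-check Redis client above now mirrors the broker's TLS settings from `dify_config`. A hedged sketch of the same construction in isolation, using placeholder values rather than the project's settings object (password and client-certificate options omitted for brevity):

```python
# Minimal sketch: building a TLS-enabled redis-py client from config flags.
# The constants below are placeholders, not the project's actual settings.
import ssl

from redis import Redis

BROKER_USE_SSL = True
REDIS_SSL_CA_CERTS = "/etc/ssl/certs/ca.pem"
REDIS_SSL_CERT_REQS = ssl.CERT_REQUIRED

client = Redis(
    host="localhost",
    port=6379,
    db=1,
    ssl=BROKER_USE_SSL,
    ssl_ca_certs=REDIS_SSL_CA_CERTS if BROKER_USE_SSL else None,
    ssl_cert_reqs=REDIS_SSL_CERT_REQS if BROKER_USE_SSL else None,
)

client.ping()  # raises on TLS or auth failure, which is what a health check wants
```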

View File

@ -1,3 +1,4 @@
import json
import logging
import os
from collections.abc import Sequence
@ -31,6 +32,11 @@ class BillingService:
compliance_download_rate_limiter = RateLimiter("compliance_download_rate_limiter", 4, 60)
# Redis key prefix for tenant plan cache
_PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
# Cache TTL: 10 minutes
_PLAN_CACHE_TTL = 600
@classmethod
def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
@ -272,14 +278,110 @@ class BillingService:
data = resp.get("data", {})
for tenant_id, plan in data.items():
subscription_plan = subscription_adapter.validate_python(plan)
results[tenant_id] = subscription_plan
try:
subscription_plan = subscription_adapter.validate_python(plan)
results[tenant_id] = subscription_plan
except Exception:
logger.exception(
"get_plan_bulk: failed to validate subscription plan for tenant(%s)", tenant_id
)
continue
except Exception:
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
logger.exception("get_plan_bulk: failed to fetch billing info batch for tenants: %s", chunk)
continue
return results
@classmethod
def _make_plan_cache_key(cls, tenant_id: str) -> str:
return f"{cls._PLAN_CACHE_KEY_PREFIX}{tenant_id}"
@classmethod
def get_plan_bulk_with_cache(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
        Bulk fetch billing subscription plans with caching to reduce billing API load in batch job scenarios.
        NOTE: if you need strong data consistency, use get_plan_bulk instead.
Returns:
Mapping of tenant_id -> {plan: str, expiration_date: int}
"""
tenant_plans: dict[str, SubscriptionPlan] = {}
if not tenant_ids:
return tenant_plans
subscription_adapter = TypeAdapter(SubscriptionPlan)
# Step 1: Batch fetch from Redis cache using mget
redis_keys = [cls._make_plan_cache_key(tenant_id) for tenant_id in tenant_ids]
try:
cached_values = redis_client.mget(redis_keys)
if len(cached_values) != len(tenant_ids):
raise Exception(
"get_plan_bulk_with_cache: unexpected error: redis mget failed: cached values length mismatch"
)
# Map cached values back to tenant_ids
cache_misses: list[str] = []
for tenant_id, cached_value in zip(tenant_ids, cached_values):
if cached_value:
try:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
plan_dict = json.loads(json_str)
subscription_plan = subscription_adapter.validate_python(plan_dict)
tenant_plans[tenant_id] = subscription_plan
except Exception:
logger.exception(
"get_plan_bulk_with_cache: process tenant(%s) failed, add to cache misses", tenant_id
)
cache_misses.append(tenant_id)
else:
cache_misses.append(tenant_id)
logger.info(
"get_plan_bulk_with_cache: cache hits=%s, cache misses=%s",
len(tenant_plans),
len(cache_misses),
)
except Exception:
logger.exception("get_plan_bulk_with_cache: redis mget failed, falling back to API")
cache_misses = list(tenant_ids)
# Step 2: Fetch missing plans from billing API
if cache_misses:
bulk_plans = BillingService.get_plan_bulk(cache_misses)
if bulk_plans:
plans_to_cache: dict[str, SubscriptionPlan] = {}
for tenant_id, subscription_plan in bulk_plans.items():
tenant_plans[tenant_id] = subscription_plan
plans_to_cache[tenant_id] = subscription_plan
# Step 3: Batch update Redis cache using pipeline
if plans_to_cache:
try:
pipe = redis_client.pipeline()
for tenant_id, subscription_plan in plans_to_cache.items():
redis_key = cls._make_plan_cache_key(tenant_id)
# Serialize dict to JSON string
json_str = json.dumps(subscription_plan)
pipe.setex(redis_key, cls._PLAN_CACHE_TTL, json_str)
pipe.execute()
logger.info(
"get_plan_bulk_with_cache: cached %s new tenant plans to Redis",
len(plans_to_cache),
)
except Exception:
logger.exception("get_plan_bulk_with_cache: redis pipeline failed")
return tenant_plans
@classmethod
def get_expired_subscription_cleanup_whitelist(cls) -> Sequence[str]:
resp = cls._send_request("GET", "/subscription/cleanup/whitelist")
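
Editor's note: `get_plan_bulk_with_cache` follows a cache-aside flow — bulk `mget` from Redis, fall back to the billing API for misses, then backfill the cache through a pipeline with the 10-minute `_PLAN_CACHE_TTL`. A hedged usage sketch for a batch job (tenant IDs are placeholders):

```python
# Hedged usage sketch for a batch job; tenant IDs are placeholders.
from services.billing_service import BillingService

tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]

# Cache-tolerant path: plans may be up to _PLAN_CACHE_TTL (600 s) stale.
plans = BillingService.get_plan_bulk_with_cache(tenant_ids)
for tenant_id, plan in plans.items():
    print(tenant_id, plan)

# When stale data is not acceptable, hit the billing API directly.
fresh_plans = BillingService.get_plan_bulk(tenant_ids)
```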

View File

@ -70,7 +70,6 @@ class ProviderResponse(BaseModel):
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: Sequence[ModelType]
@ -98,11 +97,6 @@ class ProviderResponse(BaseModel):
en_US=f"{url_prefix}/icon_small_dark/en_US",
zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans",
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
@ -116,7 +110,6 @@ class ProviderWithModelsResponse(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
@ -134,11 +127,6 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
@ -163,11 +151,6 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_small_dark = I18nObject(
en_US=f"{url_prefix}/icon_small_dark/en_US", zh_Hans=f"{url_prefix}/icon_small_dark/zh_Hans"
)
if self.icon_large is not None:
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self

View File

@ -99,7 +99,6 @@ class ModelProviderService:
description=provider_configuration.provider.description,
icon_small=provider_configuration.provider.icon_small,
icon_small_dark=provider_configuration.provider.icon_small_dark,
icon_large=provider_configuration.provider.icon_large,
background=provider_configuration.provider.background,
help=provider_configuration.provider.help,
supported_model_types=provider_configuration.provider.supported_model_types,
@ -423,7 +422,6 @@ class ModelProviderService:
label=first_model.provider.label,
icon_small=first_model.provider.icon_small,
icon_small_dark=first_model.provider.icon_small_dark,
icon_large=first_model.provider.icon_large,
status=CustomConfigurationStatus.ACTIVE,
models=[
ProviderModelWithStatusEntity(
@ -488,7 +486,6 @@ class ModelProviderService:
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types,
),
)
@ -522,7 +519,7 @@ class ModelProviderService:
:param tenant_id: workspace id
:param provider: provider name
:param icon_type: icon type (icon_small or icon_large)
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return:
"""

View File

@ -7,7 +7,6 @@ from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
@ -86,7 +85,9 @@ class ApiToolManageService:
raise ValueError(f"invalid schema: {str(e)}")
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
def convert_schema_to_tool_bundles(
schema: str, extra_info: dict | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
convert schema to tool bundles
@ -104,7 +105,7 @@ class ApiToolManageService:
provider_name: str,
icon: dict,
credentials: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
custom_disclaimer: str,
@ -113,9 +114,6 @@ class ApiToolManageService:
"""
create api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -178,9 +176,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -245,18 +240,15 @@ class ApiToolManageService:
original_provider: str,
icon: dict,
credentials: dict,
schema_type: str,
_schema_type: ApiProviderSchemaType,
schema: str,
privacy_policy: str,
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
):
"""
update api tool provider
"""
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema}")
provider_name = provider_name.strip()
# check if the provider exists
@ -281,7 +273,7 @@ class ApiToolManageService:
provider.icon = json.dumps(icon)
provider.schema = schema
provider.description = extra_info.get("description", "")
provider.schema_type_str = ApiProviderSchemaType.OPENAPI
provider.schema_type_str = schema_type
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
provider.privacy_policy = privacy_policy
provider.custom_disclaimer = custom_disclaimer
@ -322,9 +314,6 @@ class ApiToolManageService:
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -347,9 +336,6 @@ class ApiToolManageService:
db.session.delete(provider)
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -366,7 +352,7 @@ class ApiToolManageService:
tool_name: str,
credentials: dict,
parameters: dict,
schema_type: str,
schema_type: ApiProviderSchemaType,
schema: str,
):
"""

View File

@ -12,7 +12,6 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -205,9 +204,6 @@ class BuiltinToolManageService:
db_provider.name = name
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
@ -290,8 +286,6 @@ class BuiltinToolManageService:
session.rollback()
raise ValueError(str(e))
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
return {"result": "success"}
@staticmethod
@ -409,9 +403,6 @@ class BuiltinToolManageService:
)
cache.delete()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod
@ -434,8 +425,6 @@ class BuiltinToolManageService:
target_provider.is_default = True
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@staticmethod

View File

@ -1,6 +1,5 @@
import logging
from core.helper.tool_provider_cache import ToolProviderListCache
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@ -16,14 +15,6 @@ class ToolCommonService:
:return: the list of tool providers
"""
# Try to get from cache first
cached_result = ToolProviderListCache.get_cached_providers(tenant_id, typ)
if cached_result is not None:
logger.debug("Returning cached tool providers for tenant %s, type %s", tenant_id, typ)
return cached_result
# Cache miss - fetch from database
logger.debug("Cache miss for tool providers, fetching from database for tenant %s, type %s", tenant_id, typ)
providers = ToolManager.list_providers_from_api(user_id, tenant_id, typ)
# add icon
@ -32,7 +23,4 @@ class ToolCommonService:
result = [provider.to_dict() for provider in providers]
# Cache the result
ToolProviderListCache.set_cached_providers(tenant_id, typ, result)
return result

View File

@ -5,9 +5,8 @@ from datetime import datetime
from typing import Any
from sqlalchemy import or_, select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.helper.tool_provider_cache import ToolProviderListCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
@ -86,17 +85,13 @@ class WorkflowToolManageService:
except Exception as e:
raise ValueError(str(e))
with session_factory.create_session() as session, session.begin():
with Session(db.engine, expire_on_commit=False) as session, session.begin():
session.add(workflow_tool_provider)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -184,9 +179,6 @@ class WorkflowToolManageService:
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod
@ -249,9 +241,6 @@ class WorkflowToolManageService:
db.session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
@classmethod

View File

@ -48,10 +48,6 @@ class MockModelClass(PluginModelClient):
en_US="https://example.com/icon_small.png",
zh_Hans="https://example.com/icon_small.png",
),
icon_large=I18nObject(
en_US="https://example.com/icon_large.png",
zh_Hans="https://example.com/icon_large.png",
),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[

View File

@ -0,0 +1,365 @@
import json
from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from services.billing_service import BillingService
class TestBillingServiceGetPlanBulkWithCache:
"""
Comprehensive integration tests for get_plan_bulk_with_cache using testcontainers.
This test class covers all major scenarios:
- Cache hit/miss scenarios
- Redis operation failures and fallback behavior
- Invalid cache data handling
- TTL expiration handling
- Error recovery and logging
"""
@pytest.fixture(autouse=True)
def setup_redis_cleanup(self, flask_app_with_containers):
"""Clean up Redis cache before and after each test."""
with flask_app_with_containers.app_context():
# Clean up before test
yield
# Clean up after test
# Delete all test cache keys
pattern = f"{BillingService._PLAN_CACHE_KEY_PREFIX}*"
keys = redis_client.keys(pattern)
if keys:
redis_client.delete(*keys)
def _create_test_plan_data(self, plan: str = "sandbox", expiration_date: int = 1735689600):
"""Helper to create test SubscriptionPlan data."""
return {"plan": plan, "expiration_date": expiration_date}
def _set_cache(self, tenant_id: str, plan_data: dict, ttl: int = 600):
"""Helper to set cache data in Redis."""
cache_key = BillingService._make_plan_cache_key(tenant_id)
json_str = json.dumps(plan_data)
redis_client.setex(cache_key, ttl, json_str)
def _get_cache(self, tenant_id: str):
"""Helper to get cache data from Redis."""
cache_key = BillingService._make_plan_cache_key(tenant_id)
value = redis_client.get(cache_key)
if value:
if isinstance(value, bytes):
return value.decode("utf-8")
return value
return None
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
"""Test bulk plan retrieval when all tenants are in cache."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
expected_plans = {
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
"tenant-2": self._create_test_plan_data("professional", 1767225600),
"tenant-3": self._create_test_plan_data("team", 1798761600),
}
# Pre-populate cache
for tenant_id, plan_data in expected_plans.items():
self._set_cache(tenant_id, plan_data)
# Act
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 3
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-1"]["expiration_date"] == 1735689600
assert result["tenant-2"]["plan"] == "professional"
assert result["tenant-2"]["expiration_date"] == 1767225600
assert result["tenant-3"]["plan"] == "team"
assert result["tenant-3"]["expiration_date"] == 1798761600
# Verify API was not called
mock_get_plan_bulk.assert_not_called()
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
"""Test bulk plan retrieval when all tenants are not in cache."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
expected_plans = {
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
"tenant-2": self._create_test_plan_data("professional", 1767225600),
}
# Act
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 2
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-2"]["plan"] == "professional"
# Verify API was called with correct tenant_ids
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
# Verify data was written to cache
cached_1 = self._get_cache("tenant-1")
cached_2 = self._get_cache("tenant-2")
assert cached_1 is not None
assert cached_2 is not None
# Verify cache content
cached_data_1 = json.loads(cached_1)
cached_data_2 = json.loads(cached_2)
assert cached_data_1 == expected_plans["tenant-1"]
assert cached_data_2 == expected_plans["tenant-2"]
# Verify TTL is set
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
ttl_1 = redis_client.ttl(cache_key_1)
assert ttl_1 > 0
assert ttl_1 <= 600 # Should be <= 600 seconds
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
# Pre-populate cache for tenant-1 and tenant-2
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
self._set_cache("tenant-2", self._create_test_plan_data("professional", 1767225600))
# tenant-3 is not in cache
missing_plan = {"tenant-3": self._create_test_plan_data("team", 1798761600)}
# Act
with patch.object(BillingService, "get_plan_bulk", return_value=missing_plan) as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 3
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-2"]["plan"] == "professional"
assert result["tenant-3"]["plan"] == "team"
# Verify API was called only for missing tenant
mock_get_plan_bulk.assert_called_once_with(["tenant-3"])
# Verify tenant-3 data was written to cache
cached_3 = self._get_cache("tenant-3")
assert cached_3 is not None
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == missing_plan["tenant-3"]
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
"""Test fallback to API when Redis mget fails."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
expected_plans = {
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
"tenant-2": self._create_test_plan_data("professional", 1767225600),
}
# Act
with (
patch.object(redis_client, "mget", side_effect=Exception("Redis connection error")),
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk,
):
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 2
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-2"]["plan"] == "professional"
# Verify API was called for all tenants (fallback)
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
# Verify data was written to cache after fallback
cached_1 = self._get_cache("tenant-1")
cached_2 = self._get_cache("tenant-2")
assert cached_1 is not None
assert cached_2 is not None
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
"""Test fallback to API when cache contains invalid JSON."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
# Set valid cache for tenant-1
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
# Set invalid JSON for tenant-2
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
redis_client.setex(cache_key_2, 600, "invalid json {")
# tenant-3 is not in cache
expected_plans = {
"tenant-2": self._create_test_plan_data("professional", 1767225600),
"tenant-3": self._create_test_plan_data("team", 1798761600),
}
# Act
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 3
assert result["tenant-1"]["plan"] == "sandbox" # From cache
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
assert result["tenant-3"]["plan"] == "team" # From API
# Verify API was called for tenant-2 and tenant-3
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
# Verify tenant-2's invalid JSON was replaced with correct data in cache
cached_2 = self._get_cache("tenant-2")
assert cached_2 is not None
cached_data_2 = json.loads(cached_2)
assert cached_data_2 == expected_plans["tenant-2"]
assert cached_data_2["plan"] == "professional"
assert cached_data_2["expiration_date"] == 1767225600
# Verify tenant-2 cache has correct TTL
cache_key_2_new = BillingService._make_plan_cache_key("tenant-2")
ttl_2 = redis_client.ttl(cache_key_2_new)
assert ttl_2 > 0
assert ttl_2 <= 600
# Verify tenant-3 data was also written to cache
cached_3 = self._get_cache("tenant-3")
assert cached_3 is not None
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == expected_plans["tenant-3"]
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2", "tenant-3"]
# Set valid cache for tenant-1
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600))
# Set invalid plan data for tenant-2 (missing expiration_date)
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
invalid_data = json.dumps({"plan": "professional"}) # Missing expiration_date
redis_client.setex(cache_key_2, 600, invalid_data)
# tenant-3 is not in cache
expected_plans = {
"tenant-2": self._create_test_plan_data("professional", 1767225600),
"tenant-3": self._create_test_plan_data("team", 1798761600),
}
# Act
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 3
assert result["tenant-1"]["plan"] == "sandbox" # From cache
assert result["tenant-2"]["plan"] == "professional" # From API (fallback)
assert result["tenant-3"]["plan"] == "team" # From API
# Verify API was called for tenant-2 and tenant-3
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
"""Test that pipeline failure doesn't affect return value."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
expected_plans = {
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
"tenant-2": self._create_test_plan_data("professional", 1767225600),
}
# Act
with (
patch.object(BillingService, "get_plan_bulk", return_value=expected_plans),
patch.object(redis_client, "pipeline") as mock_pipeline,
):
# Create a mock pipeline that fails on execute
mock_pipe = mock_pipeline.return_value
mock_pipe.execute.side_effect = Exception("Pipeline execution failed")
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert - Function should still return correct result despite pipeline failure
assert len(result) == 2
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-2"]["plan"] == "professional"
# Verify pipeline was attempted
mock_pipeline.assert_called_once()
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
"""Test with empty tenant_ids list."""
with flask_app_with_containers.app_context():
# Act
with patch.object(BillingService, "get_plan_bulk") as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache([])
# Assert
assert result == {}
assert len(result) == 0
# Verify no API calls
mock_get_plan_bulk.assert_not_called()
            # Ideally we would also assert that redis_client.mget was never called, but that
            # would require extra mocking, so we only verify the returned result here.
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
"""Test that expired cache keys are treated as cache misses."""
with flask_app_with_containers.app_context():
# Arrange
tenant_ids = ["tenant-1", "tenant-2"]
# Set cache for tenant-1 with very short TTL (1 second) to simulate expiration
self._set_cache("tenant-1", self._create_test_plan_data("sandbox", 1735689600), ttl=1)
# Wait for TTL to expire (key will be deleted by Redis)
import time
time.sleep(2)
# Verify cache is expired (key doesn't exist)
cache_key_1 = BillingService._make_plan_cache_key("tenant-1")
exists = redis_client.exists(cache_key_1)
assert exists == 0 # Key doesn't exist (expired)
# tenant-2 is not in cache
expected_plans = {
"tenant-1": self._create_test_plan_data("sandbox", 1735689600),
"tenant-2": self._create_test_plan_data("professional", 1767225600),
}
# Act
with patch.object(BillingService, "get_plan_bulk", return_value=expected_plans) as mock_get_plan_bulk:
result = BillingService.get_plan_bulk_with_cache(tenant_ids)
# Assert
assert len(result) == 2
assert result["tenant-1"]["plan"] == "sandbox"
assert result["tenant-2"]["plan"] == "professional"
# Verify API was called for both tenants (tenant-1 expired, tenant-2 missing)
mock_get_plan_bulk.assert_called_once_with(tenant_ids)
# Verify both were written to cache with correct TTL
cache_key_1_new = BillingService._make_plan_cache_key("tenant-1")
cache_key_2 = BillingService._make_plan_cache_key("tenant-2")
ttl_1_new = redis_client.ttl(cache_key_1_new)
ttl_2 = redis_client.ttl(cache_key_2)
assert ttl_1_new > 0
assert ttl_1_new <= 600
assert ttl_2 > 0
assert ttl_2 <= 600

View File

@ -228,7 +228,6 @@ class TestModelProviderService:
mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity.icon_small_dark = None
mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity.background = "#FF6B6B"
mock_provider_entity.help = None
mock_provider_entity.supported_model_types = [ModelType.LLM, ModelType.TEXT_EMBEDDING]
@ -302,7 +301,6 @@ class TestModelProviderService:
mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"}
mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity_llm.icon_small_dark = None
mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity_llm.background = "#FF6B6B"
mock_provider_entity_llm.help = None
mock_provider_entity_llm.supported_model_types = [ModelType.LLM]
@ -316,7 +314,6 @@ class TestModelProviderService:
mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"}
mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}
mock_provider_entity_embedding.icon_small_dark = None
mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}
mock_provider_entity_embedding.background = "#4ECDC4"
mock_provider_entity_embedding.help = None
mock_provider_entity_embedding.supported_model_types = [ModelType.TEXT_EMBEDDING]
@ -419,7 +416,6 @@ class TestModelProviderService:
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
supported_model_types=[ModelType.LLM],
configurate_methods=[],
models=[],
@ -431,7 +427,6 @@ class TestModelProviderService:
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
supported_model_types=[ModelType.LLM],
configurate_methods=[],
models=[],
@ -655,7 +650,6 @@ class TestModelProviderService:
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"),
supported_model_types=[ModelType.LLM],
),
)
@ -1027,7 +1021,6 @@ class TestModelProviderService:
label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
icon_small_dark=None,
icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
),
model="gpt-3.5-turbo",
model_type=ModelType.LLM,
@ -1045,7 +1038,6 @@ class TestModelProviderService:
label={"en_US": "OpenAI", "zh_Hans": "OpenAI"},
icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"},
icon_small_dark=None,
icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"},
),
model="gpt-4",
model_type=ModelType.LLM,

View File

@ -0,0 +1,69 @@
import builtins
from types import SimpleNamespace
from unittest.mock import patch
from flask.views import MethodView as FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = False
if not hasattr(builtins, "MethodView"):
builtins.MethodView = FlaskMethodView
_NEEDS_METHOD_VIEW_CLEANUP = True
from controllers.common.fields import Parameters, Site
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import IconType
def test_parameters_model_round_trip():
parameters = get_parameters_from_feature_dict(features_dict={}, user_input_form=[])
model = Parameters.model_validate(parameters)
assert model.model_dump(mode="json") == parameters
def test_site_icon_url_uses_signed_url_for_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.IMAGE,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url", return_value="signed") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url == "signed"
mock_helper.assert_called_once_with("file-id")
def test_site_icon_url_is_none_for_non_image_icon():
site = SimpleNamespace(
title="Example",
chat_color_theme=None,
chat_color_theme_inverted=False,
icon_type=IconType.EMOJI,
icon="file-id",
icon_background=None,
description=None,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
default_language="en-US",
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
with patch("controllers.common.fields.file_helpers.get_signed_file_url") as mock_helper:
model = Site.model_validate(site)
assert model.icon_url is None
mock_helper.assert_not_called()

Some files were not shown because too many files have changed in this diff